From 8b45096927bd5aaaaf6dd5bfff1cbd480852f769 Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Wed, 14 Sep 2022 19:31:24 +0530
Subject: [PATCH] [CrossAttention] add different method for sliced attention
 (#446)

* add different method for sliced attention

* Update src/diffusers/models/attention.py

* Apply suggestions from code review

* Update src/diffusers/models/attention.py

Co-authored-by: Patrick von Platen
---
 src/diffusers/models/attention.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index accddacd..55062c32 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -262,11 +262,24 @@ class CrossAttention(nn.Module):
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
-        hidden_states = self._attention(query, key, value, sequence_length, dim)
+
+        if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+            hidden_states = self._attention(query, key, value)
+        else:
+            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
 
         return self.to_out(hidden_states)
 
-    def _attention(self, query, key, value, sequence_length, dim):
+    def _attention(self, query, key, value):
+        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_probs = attention_scores.softmax(dim=-1)
+        # compute attention output
+        hidden_states = torch.matmul(attention_probs, value)
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
+    def _sliced_attention(self, query, key, value, sequence_length, dim):
         batch_size_attention = query.shape[0]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype