diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index accddacd..55062c32 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -262,11 +262,24 @@ class CrossAttention(nn.Module):
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used

         # attention, what we cannot get enough of
-        hidden_states = self._attention(query, key, value, sequence_length, dim)
+
+        if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+            hidden_states = self._attention(query, key, value)
+        else:
+            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

         return self.to_out(hidden_states)

-    def _attention(self, query, key, value, sequence_length, dim):
+    def _attention(self, query, key, value):
+        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        attention_probs = attention_scores.softmax(dim=-1)
+        # compute attention output
+        hidden_states = torch.matmul(attention_probs, value)
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
+    def _sliced_attention(self, query, key, value, sequence_length, dim):
         batch_size_attention = query.shape[0]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
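For context, here is a minimal standalone sketch of the slicing idea this diff introduces: rather than materializing the full (batch * heads, seq_len, seq_len) score matrix at once, attention is computed one slice of the batch-of-heads dimension at a time, trading throughput for lower peak memory. This is an illustrative sketch, not the diffusers code; the function name `sliced_attention`, the explicit `slice_size` and `scale` arguments, and the tensor shapes in the demo are assumptions chosen for clarity.

import torch


def sliced_attention(query, key, value, slice_size, scale):
    # query/key/value: (batch * heads, sequence_length, head_dim)
    batch_x_heads, sequence_length, head_dim = query.shape
    hidden_states = torch.zeros_like(query)
    for start in range(0, batch_x_heads, slice_size):
        end = start + slice_size
        # only this slice's (slice_size, seq_len, seq_len) score matrix lives in memory
        scores = torch.matmul(query[start:end], key[start:end].transpose(-1, -2)) * scale
        probs = scores.softmax(dim=-1)
        hidden_states[start:end] = torch.matmul(probs, value[start:end])
    return hidden_states


if __name__ == "__main__":
    torch.manual_seed(0)
    q, k, v = (torch.randn(8, 64, 40) for _ in range(3))
    scale = 40 ** -0.5
    full = torch.matmul(torch.matmul(q, k.transpose(-1, -2)).mul(scale).softmax(dim=-1), v)
    sliced = sliced_attention(q, k, v, slice_size=2, scale=scale)
    print(torch.allclose(full, sliced, atol=1e-6))  # slicing changes memory use, not the result

The guard in the diff (`self._slice_size is None or query.shape[0] // self._slice_size == 1`) keeps the fast single-matmul path whenever slicing would produce only one slice anyway.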