[CrossAttention] add different method for sliced attention (#446)
* add different method for sliced attention * Update src/diffusers/models/attention.py * Apply suggestions from code review * Update src/diffusers/models/attention.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
1a69c6ff0e
commit
8b45096927
|
@ -262,11 +262,24 @@ class CrossAttention(nn.Module):
|
||||||
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
|
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
|
||||||
|
|
||||||
# attention, what we cannot get enough of
|
# attention, what we cannot get enough of
|
||||||
hidden_states = self._attention(query, key, value, sequence_length, dim)
|
|
||||||
|
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
||||||
|
hidden_states = self._attention(query, key, value)
|
||||||
|
else:
|
||||||
|
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
|
||||||
|
|
||||||
return self.to_out(hidden_states)
|
return self.to_out(hidden_states)
|
||||||
|
|
||||||
def _attention(self, query, key, value, sequence_length, dim):
|
def _attention(self, query, key, value):
|
||||||
|
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
|
||||||
|
attention_probs = attention_scores.softmax(dim=-1)
|
||||||
|
# compute attention output
|
||||||
|
hidden_states = torch.matmul(attention_probs, value)
|
||||||
|
# reshape hidden_states
|
||||||
|
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def _sliced_attention(self, query, key, value, sequence_length, dim):
|
||||||
batch_size_attention = query.shape[0]
|
batch_size_attention = query.shape[0]
|
||||||
hidden_states = torch.zeros(
|
hidden_states = torch.zeros(
|
||||||
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
||||||
|
|
Loading…
Reference in New Issue