diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index b4e5f2e0..c2f27bd9 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -274,13 +274,8 @@ class CrossAttention(nn.Module):
         return self.to_out(hidden_states)
 
     def _attention(self, query, key, value):
-        attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
-            query,
-            key.transpose(-1, -2),
-            beta=0,
-            alpha=self.scale,
-        )
+        # TODO: use baddbmm for better performance
+        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
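
For context (not part of the patch): the two formulations compute the same scaled attention scores, since `torch.baddbmm` with `beta=0` ignores its (uninitialized) input tensor and returns `alpha * (query @ key^T)`. A minimal sketch verifying the equivalence; the shapes and `scale` value below are illustrative assumptions, not taken from the module.

```python
import torch

batch, seq_q, seq_k, dim = 2, 4, 4, 8
scale = dim ** -0.5  # assumed 1/sqrt(head_dim) scaling for illustration

query = torch.randn(batch, seq_q, dim)
key = torch.randn(batch, seq_k, dim)

# Formulation removed by the diff: baddbmm with beta=0, alpha=scale.
scores_baddbmm = torch.baddbmm(
    torch.empty(batch, seq_q, seq_k, dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

# Formulation introduced by the diff: plain matmul followed by scaling.
scores_matmul = torch.matmul(query, key.transpose(-1, -2)) * scale

assert torch.allclose(scores_baddbmm, scores_matmul, atol=1e-6)
```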