Fix slow tests (#689)

* revert using baddbmm in attention
- to fix `test_stable_diffusion_memory_chunking` test

* styling
This commit is contained in:
Nouamane Tazi 2022-09-30 17:45:02 +01:00 committed by GitHub
parent 552b967020
commit b2cfc7a04c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 2 additions and 7 deletions

View File

@ -274,13 +274,8 @@ class CrossAttention(nn.Module):
return self.to_out(hidden_states)
def _attention(self, query, key, value):
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
# TODO: use baddbmm for better performance
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attention_probs = attention_scores.softmax(dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value)