Fix slow tests (#689)
* revert using baddbmm in attention - to fix `test_stable_diffusion_memory_chunking` test
* styling
This commit is contained in:
parent
552b967020
commit
b2cfc7a04c
|
@@ -274,13 +274,8 @@ class CrossAttention(nn.Module):
|
||||||
return self.to_out(hidden_states)
|
return self.to_out(hidden_states)
|
||||||
|
|
||||||
def _attention(self, query, key, value):
|
def _attention(self, query, key, value):
|
||||||
attention_scores = torch.baddbmm(
|
# TODO: use baddbmm for better performance
|
||||||
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
|
||||||
query,
|
|
||||||
key.transpose(-1, -2),
|
|
||||||
beta=0,
|
|
||||||
alpha=self.scale,
|
|
||||||
)
|
|
||||||
attention_probs = attention_scores.softmax(dim=-1)
|
attention_probs = attention_scores.softmax(dim=-1)
|
||||||
# compute attention output
|
# compute attention output
|
||||||
hidden_states = torch.matmul(attention_probs, value)
|
hidden_states = torch.matmul(attention_probs, value)
|
||||||
|
|
Loading…
Reference in New Issue