[attention] Fix attention (#2656)

* [attention] Fix attention

* fix

* correct
This commit is contained in:
Patrick von Platen 2023-03-13 19:10:33 +01:00 committed by GitHub
parent fa7a576191
commit 4ae54b3789
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 4 deletions

View File

@ -271,9 +271,10 @@ class BasicTransformerBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states,
encoder_hidden_states=None,
timestep=None,
attention_mask=None, attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=None,
cross_attention_kwargs=None, cross_attention_kwargs=None,
class_labels=None, class_labels=None,
): ):
@ -302,12 +303,14 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = ( norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
) )
# TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
# prepare attention mask here
# 2. Cross-Attention # 2. Cross-Attention
attn_output = self.attn2( attn_output = self.attn2(
norm_hidden_states, norm_hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask, attention_mask=encoder_attention_mask,
**cross_attention_kwargs, **cross_attention_kwargs,
) )
hidden_states = attn_output + hidden_states hidden_states = attn_output + hidden_states

View File

@ -737,7 +737,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
# make sure that more than 4 GB is allocated # make sure that more than 4 GB is allocated
mem_bytes = torch.cuda.max_memory_allocated() mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes > 4e9 assert mem_bytes > 5e9
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2 assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2
def test_stable_diffusion_fp16_vs_autocast(self): def test_stable_diffusion_fp16_vs_autocast(self):