diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index b476e762..6da318e6 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -271,9 +271,10 @@ class BasicTransformerBlock(nn.Module):
     def forward(
         self,
         hidden_states,
-        encoder_hidden_states=None,
-        timestep=None,
         attention_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        timestep=None,
         cross_attention_kwargs=None,
         class_labels=None,
     ):
@@ -302,12 +303,14 @@ class BasicTransformerBlock(nn.Module):
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
+            # TODO (Birch-San): Here we should prepare the encoder_attention_mask correctly
+            # prepare attention mask here

             # 2. Cross-Attention
             attn_output = self.attn2(
                 norm_hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=encoder_attention_mask,
                 **cross_attention_kwargs,
             )
             hidden_states = attn_output + hidden_states
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index dfb14617..d4fd3045 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -737,7 +737,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
-        # make sure that more than 4 GB is allocated
+        # make sure that more than 5 GB is allocated
         mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes > 4e9
+        assert mem_bytes > 5e9

         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2

     def test_stable_diffusion_fp16_vs_autocast(self):
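
Note on the TODO in the attention.py hunk: the new encoder_attention_mask argument is forwarded to self.attn2 as-is, while the comment says it still needs to be prepared. Below is a minimal sketch of one common preparation, assuming the mask arrives as a binary (batch, key_len) padding mask with 1 = attend and 0 = ignore; the helper name prepare_encoder_attention_mask and the additive-bias convention are illustrative assumptions, not the diffusers API:

    from typing import Optional

    import torch

    def prepare_encoder_attention_mask(
        mask: Optional[torch.Tensor], dtype: torch.dtype
    ) -> Optional[torch.Tensor]:
        # No mask supplied: attend to every encoder token, as before.
        if mask is None:
            return None
        # Map 1 -> 0.0 (keep) and 0 -> a large negative value (drop), giving an
        # additive bias that, broadcast over (batch, query_len, key_len) attention
        # scores before the softmax, drives masked key positions to ~0 probability.
        bias = (1.0 - mask.to(dtype)) * torch.finfo(dtype).min
        # Insert a query-length axis so the bias broadcasts across all queries.
        return bias.unsqueeze(1)  # (batch, 1, key_len)

For example, prepare_encoder_attention_mask(torch.tensor([[1, 1, 1, 0, 0]]), torch.float32) returns a (1, 1, 5) bias whose last two entries are large negative values, so the two padded encoder tokens receive near-zero attention weight.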