[attention] Fix attention (#2656)
* [attention] Fix attention * fix * correct
This commit is contained in:
parent
fa7a576191
commit
4ae54b3789
|
@ -271,9 +271,10 @@ class BasicTransformerBlock(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
timestep=None,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
timestep=None,
|
||||
cross_attention_kwargs=None,
|
||||
class_labels=None,
|
||||
):
|
||||
|
@ -302,12 +303,14 @@ class BasicTransformerBlock(nn.Module):
|
|||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
# TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
|
||||
# prepare attention mask here
|
||||
|
||||
# 2. Cross-Attention
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
|
|
@ -737,7 +737,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
|
|||
|
||||
# make sure that more than 4 GB is allocated
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
assert mem_bytes > 4e9
|
||||
assert mem_bytes > 5e9
|
||||
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_fp16_vs_autocast(self):
|
||||
|
|
Loading…
Reference in New Issue