[attention] Fix attention (#2656)
* [attention] Fix attention
* fix
* correct
parent fa7a576191
commit 4ae54b3789
@@ -271,9 +271,10 @@ class BasicTransformerBlock(nn.Module):
     def forward(
         self,
         hidden_states,
-        encoder_hidden_states=None,
-        timestep=None,
         attention_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        timestep=None,
         cross_attention_kwargs=None,
         class_labels=None,
     ):
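
Note: the reorder is only transparent to keyword callers. A positional caller that previously passed encoder_hidden_states right after hidden_states would now bind it to attention_mask. A minimal sketch of a safe call site, assuming `block` is an already-constructed BasicTransformerBlock with cross-attention enabled; the tensor shapes are illustrative, not taken from this diff:

    import torch

    # Hypothetical shapes: batch=2, 64 image tokens, dim=320, 77 text tokens.
    hidden_states = torch.randn(2, 64, 320)
    encoder_hidden_states = torch.randn(2, 77, 320)
    encoder_attention_mask = torch.ones(2, 77, dtype=torch.bool)

    # Passing everything by keyword is robust to the argument reorder above.
    out = block(
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        timestep=None,
    )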
@@ -302,12 +303,14 @@ class BasicTransformerBlock(nn.Module):
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
+            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
+            # prepare attention mask here
 
             # 2. Cross-Attention
             attn_output = self.attn2(
                 norm_hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=encoder_attention_mask,
                 **cross_attention_kwargs,
             )
             hidden_states = attn_output + hidden_states
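
Note: the TODO left in this hunk concerns converting the encoder padding mask into the additive form the attention op expects. A sketch of one common preparation, assuming a boolean mask of shape (batch, key_len) where True marks real tokens; this helper is illustrative and not part of the diff:

    import torch

    def prepare_encoder_attention_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Turn a (batch, key_len) boolean padding mask into an additive bias.

        True entries are kept (bias 0.0); False entries get a large negative
        value so softmax assigns them ~zero attention weight.
        """
        bias = torch.zeros(mask.shape, dtype=dtype)
        bias.masked_fill_(~mask.bool(), torch.finfo(dtype).min)
        # Broadcast over heads and query positions: (batch, 1, key_len).
        return bias.unsqueeze(1)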
@@ -737,7 +737,7 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
 
         # make sure that more than 4 GB is allocated
         mem_bytes = torch.cuda.max_memory_allocated()
-        assert mem_bytes > 4e9
+        assert mem_bytes > 5e9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-2
 
     def test_stable_diffusion_fp16_vs_autocast(self):
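
Note: for reference, the measurement pattern this assertion relies on is to reset the peak-memory counter before the run and read it afterwards. A minimal standalone sketch; `run_pipeline` is a placeholder for the actual pipeline invocation, not a name from this diff:

    import torch

    def peak_cuda_bytes(fn) -> int:
        """Run fn and return the peak CUDA memory allocated during the call."""
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        fn()
        return torch.cuda.max_memory_allocated()

    # Usage (placeholder): assert peak_cuda_bytes(run_pipeline) > 5e9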