From cf4227cd1e1a361aaf26109f2e970aa9abb620b7 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Wed, 15 Mar 2023 18:04:05 +0100
Subject: [PATCH] T5Attention support for cross-attention (#2654)

* fix AttnProcessor2_0

Fix use of AttnProcessor2_0 for cross attention with mask

* added scale_qk and out_bias flags

* fixed for xformers

* check if it has scale argument

* Update cross_attention.py

* check torch version

* fix sliced attn

* style

* set scale

* fix test

* fixed addedKV processor

* revert back AttnProcessor2_0

* if missing if

* fix inner_dim

---------

Co-authored-by: Patrick von Platen
---
 src/diffusers/models/cross_attention.py       | 56 ++++++++++++-------
 tests/models/test_models_unet_2d_condition.py |  3 +-
 2 files changed, 39 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 9f994064..a0ecfb0f 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -59,6 +59,8 @@ class CrossAttention(nn.Module):
         cross_attention_norm: bool = False,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
@@ -68,7 +70,7 @@ class CrossAttention(nn.Module):
         self.upcast_softmax = upcast_softmax
         self.cross_attention_norm = cross_attention_norm

-        self.scale = dim_head**-0.5
+        self.scale = dim_head**-0.5 if scale_qk else 1.0

         self.heads = heads
         # for slice_size > 0 the attention score computation
@@ -95,14 +97,17 @@ class CrossAttention(nn.Module):
             self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

         self.to_out = nn.ModuleList([])
-        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))

         # set attention processor
-        # We use the AttnProcessor2_0 by default when torch2.x is used which uses
+        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         if processor is None:
-            processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else CrossAttnProcessor()
+            )
         self.set_processor(processor)

     def set_use_memory_efficient_attention_xformers(
@@ -295,7 +300,9 @@ class CrossAttnProcessor:
         encoder_hidden_states=None,
         attention_mask=None,
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states)
@@ -362,7 +369,9 @@ class LoRACrossAttnProcessor(nn.Module):
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@@ -435,7 +444,9 @@ class XFormersCrossAttnProcessor:
         self.attention_op = attention_op

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -454,7 +465,7 @@ class XFormersCrossAttnProcessor:
         value = attn.head_to_batch_dim(value).contiguous()

         hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         )
         hidden_states = hidden_states.to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)
@@ -472,7 +483,10 @@ class AttnProcessor2_0:
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, inner_dim = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        inner_dim = hidden_states.shape[-1]

         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -496,6 +510,7 @@ class AttnProcessor2_0:
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
@@ -527,7 +542,9 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@@ -542,7 +559,7 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
         value = attn.head_to_batch_dim(value).contiguous()

         hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         )
         hidden_states = attn.batch_to_head_dim(hidden_states)

@@ -559,8 +576,9 @@ class SlicedAttnProcessor:
         self.slice_size = slice_size

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states)
@@ -577,12 +595,12 @@ class SlicedAttnProcessor:
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)

-        batch_size_attention = query.shape[0]
+        batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )

-        for i in range(hidden_states.shape[0] // self.slice_size):
+        for i in range(batch_size_attention // self.slice_size):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
@@ -638,12 +656,12 @@ class SlicedAttnAddedKVProcessor:
         key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
         value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

-        batch_size_attention = query.shape[0]
+        batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )

-        for i in range(hidden_states.shape[0] // self.slice_size):
+        for i in range(batch_size_attention // self.slice_size):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index c1f3bc05..e313fcfb 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -118,7 +118,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         model.enable_xformers_memory_efficient_attention()

         assert (
-            model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
+            model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
+            == "XFormersCrossAttnProcessor"
         ), "xformers is not enabled"

     @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
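
Usage note (not part of the patch): a minimal sketch of how the new scale_qk and
out_bias flags are meant to be combined for T5-style attention, i.e. no softmax
scaling of Q.K^T and no bias on the output projection. It assumes a diffusers
checkout that already contains this change; the module path, class names, and
constructor arguments come from the diff above, while the tensor sizes are made
up for illustration.

import torch
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor

# T5-style block: scale_qk=False keeps attn.scale at 1.0, out_bias=False drops
# the bias on the to_out projection.
attn = CrossAttention(
    query_dim=512,
    cross_attention_dim=512,
    heads=8,
    dim_head=64,
    scale_qk=False,
    out_bias=False,
)

# With scale_qk=False the constructor falls back to CrossAttnProcessor even on
# torch 2.x, because F.scaled_dot_product_attention cannot take a custom scale
# yet (see the TODO added in AttnProcessor2_0 above).
assert isinstance(attn.processor, CrossAttnProcessor)
assert attn.scale == 1.0  # dim_head**-0.5 is skipped when scale_qk=False

hidden_states = torch.randn(2, 77, 512)          # (batch, query tokens, query_dim)
encoder_hidden_states = torch.randn(2, 32, 512)  # (batch, key/value tokens, cross_attention_dim)
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 77, 512])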