T5Attention support for cross-attention (#2654)
* fix AttnProcessor2_0: fix use of AttnProcessor2_0 for cross attention with mask
* added scale_qk and out_bias flags
* fixed for xformers
* check if it has scale argument
* Update cross_attention.py
* check torch version
* fix sliced attn
* style
* set scale
* fix test
* fixed addedKV processor
* revert back AttnProcessor2_0
* if missing if
* fix inner_dim

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
commit cf4227cd1e
parent 9d1341d69b
@@ -59,6 +59,8 @@ class CrossAttention(nn.Module):
         cross_attention_norm: bool = False,
         added_kv_proj_dim: Optional[int] = None,
         norm_num_groups: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
@@ -68,7 +70,7 @@ class CrossAttention(nn.Module):
         self.upcast_softmax = upcast_softmax
         self.cross_attention_norm = cross_attention_norm

-        self.scale = dim_head**-0.5
+        self.scale = dim_head**-0.5 if scale_qk else 1.0

         self.heads = heads
         # for slice_size > 0 the attention score computation
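A minimal standalone sketch (not the diffusers code itself) of what the new `scale_qk` flag controls: with `scale_qk=True` the attention scores get the usual `dim_head**-0.5` factor, with `scale_qk=False` they stay unscaled, which is what T5-style attention expects.

    import torch

    def attention_scores(query, key, dim_head, scale_qk=True):
        # query, key: (batch * heads, tokens, dim_head)
        scale = dim_head**-0.5 if scale_qk else 1.0  # same rule as the new flag above
        return scale * (query @ key.transpose(-1, -2))

    q = torch.randn(2, 4, 8)
    k = torch.randn(2, 6, 8)
    scaled = attention_scores(q, k, dim_head=8)                    # standard 1/sqrt(d) scores
    unscaled = attention_scores(q, k, dim_head=8, scale_qk=False)  # T5-style, unscaled scores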
@@ -95,14 +97,17 @@ class CrossAttention(nn.Module):
             self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

         self.to_out = nn.ModuleList([])
-        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))

         # set attention processor
-        # We use the AttnProcessor2_0 by default when torch2.x is used which uses
+        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         if processor is None:
-            processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else CrossAttnProcessor()
+            )
         self.set_processor(processor)

     def set_use_memory_efficient_attention_xformers(
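The selection above, restated as a tiny helper (illustrative only; it returns class names instead of instantiating the processors): PyTorch 2.x exposes `torch.nn.functional.scaled_dot_product_attention`, but on torch 2.0 that function has no `scale` argument, so a custom scale (`scale_qk=False`) has to fall back to the plain processor.

    import torch.nn.functional as F

    def default_processor_name(scale_qk: bool) -> str:
        # Mirrors the default-processor choice in the hunk above.
        use_sdpa = hasattr(F, "scaled_dot_product_attention") and scale_qk
        return "AttnProcessor2_0" if use_sdpa else "CrossAttnProcessor"

    print(default_processor_name(scale_qk=True))   # "AttnProcessor2_0" on torch >= 2.0
    print(default_processor_name(scale_qk=False))  # "CrossAttnProcessor": SDPA cannot take attn.scale on torch 2.0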
@@ -295,7 +300,9 @@ class CrossAttnProcessor:
         encoder_hidden_states=None,
         attention_mask=None,
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         query = attn.to_q(hidden_states)

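The bug this fixes: for cross attention the mask covers the encoder tokens (the key/value sequence), so the `sequence_length` handed to `prepare_attention_mask` must be the encoder length, not the query length. A shape-only sketch, assuming a boolean mask of shape (batch, encoder_len):

    import torch

    batch, query_len, encoder_len, dim = 2, 64, 77, 320
    hidden_states = torch.randn(batch, query_len, dim)
    encoder_hidden_states = torch.randn(batch, encoder_len, dim)
    attention_mask = torch.ones(batch, encoder_len, dtype=torch.bool)

    # Old: sequence_length came from hidden_states (64) and did not match the
    # mask's 77 key positions. New: take it from encoder_hidden_states when given.
    batch_size, sequence_length, _ = (
        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    )
    assert sequence_length == attention_mask.shape[-1]  # 77 == 77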
@@ -362,7 +369,9 @@ class LoRACrossAttnProcessor(nn.Module):
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
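For context on the unchanged last line of the hunk: the LoRA processors add a scaled low-rank update next to the frozen projection. A minimal sketch of that pattern with a hypothetical `TinyLoRALinear` (not the diffusers `LoRALinearLayer`):

    import torch
    import torch.nn as nn

    class TinyLoRALinear(nn.Module):
        """Low-rank update up(down(x)) used alongside a frozen base projection."""

        def __init__(self, in_features, out_features, rank=4):
            super().__init__()
            self.down = nn.Linear(in_features, rank, bias=False)
            self.up = nn.Linear(rank, out_features, bias=False)
            nn.init.zeros_(self.up.weight)  # start as a no-op

        def forward(self, x):
            return self.up(self.down(x))

    to_q = nn.Linear(320, 320)           # frozen base projection
    to_q_lora = TinyLoRALinear(320, 320)
    hidden_states = torch.randn(2, 64, 320)
    scale = 1.0

    # Same combination the LoRA processors use: base output + scale * LoRA output.
    query = to_q(hidden_states) + scale * to_q_lora(hidden_states)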
@@ -435,7 +444,9 @@ class XFormersCrossAttnProcessor:
         self.attention_op = attention_op

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )

         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -454,7 +465,7 @@ class XFormersCrossAttnProcessor:
         value = attn.head_to_batch_dim(value).contiguous()

         hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         )
         hidden_states = hidden_states.to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)
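Previously the xformers path always used the kernel's default softmax scaling (1/sqrt(head_dim)); passing `scale=attn.scale` lets it honour `scale_qk=False` as well. A plain-PyTorch reference for what the call computes under an explicit scale (kernel details and attention bias omitted):

    import torch

    def reference_attention(query, key, value, scale):
        # query/key/value: (batch * heads, tokens, head_dim), the layout after head_to_batch_dim
        scores = scale * (query @ key.transpose(-1, -2))
        return scores.softmax(dim=-1) @ value

    q, k, v = (torch.randn(2, 16, 8) for _ in range(3))
    out_default = reference_attention(q, k, v, scale=8**-0.5)  # default 1/sqrt(head_dim)
    out_unscaled = reference_attention(q, k, v, scale=1.0)     # scale_qk=False path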
@@ -472,7 +483,10 @@ class AttnProcessor2_0:
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, inner_dim = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        inner_dim = hidden_states.shape[-1]

         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -496,6 +510,7 @@ class AttnProcessor2_0:
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
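Why the TODO stays: on torch 2.0 `F.scaled_dot_product_attention` always applies its built-in 1/sqrt(head_dim) scaling and only gains a `scale` argument in torch 2.1, which is exactly why `AttnProcessor2_0` is skipped when `scale_qk=False`. A small equivalence check, runnable on torch >= 2.0:

    import math
    import torch
    import torch.nn.functional as F

    batch, heads, tokens, head_dim = 1, 2, 4, 8
    q, k, v = (torch.randn(batch, heads, tokens, head_dim) for _ in range(3))

    # What SDPA computes by default (no way to pass attn.scale here on torch 2.0).
    sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

    # Manual reference with the same implicit scale.
    scores = (q @ k.transpose(-1, -2)) / math.sqrt(head_dim)
    manual_out = scores.softmax(dim=-1) @ v

    assert torch.allclose(sdpa_out, manual_out, atol=1e-5)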
@@ -527,7 +542,9 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
     ):
-        batch_size, sequence_length, _ = hidden_states.shape
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@@ -542,7 +559,7 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
         value = attn.head_to_batch_dim(value).contiguous()

         hidden_states = xformers.ops.memory_efficient_attention(
-            query, key, value, attn_bias=attention_mask, op=self.attention_op
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         )
         hidden_states = attn.batch_to_head_dim(hidden_states)

@@ -559,8 +576,9 @@ class SlicedAttnProcessor:
         self.slice_size = slice_size

     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
-
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

         query = attn.to_q(hidden_states)
@@ -577,12 +595,12 @@ class SlicedAttnProcessor:
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)

-        batch_size_attention = query.shape[0]
+        batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )

-        for i in range(hidden_states.shape[0] // self.slice_size):
+        for i in range(batch_size_attention // self.slice_size):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size

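Why the shapes change: after `head_to_batch_dim`, `query.shape[0]` is batch_size * heads and `query.shape[1]` is the number of query tokens, which for cross attention differs from `sequence_length` (the encoder length used for the mask). Allocating the output with `query_tokens` and looping over `batch_size_attention` keeps the buffer and the slice loop consistent. A shape-only sketch with illustrative sizes:

    import torch

    batch_size, heads, query_tokens, encoder_len, head_dim = 2, 8, 64, 77, 40
    slice_size = 4

    # Layout after head_to_batch_dim: (batch * heads, tokens, head_dim)
    query = torch.randn(batch_size * heads, query_tokens, head_dim)
    key = torch.randn(batch_size * heads, encoder_len, head_dim)
    value = torch.randn(batch_size * heads, encoder_len, head_dim)

    batch_size_attention, query_tokens, _ = query.shape
    hidden_states = torch.zeros(batch_size_attention, query_tokens, head_dim)

    for i in range(batch_size_attention // slice_size):
        start_idx, end_idx = i * slice_size, (i + 1) * slice_size
        scores = head_dim**-0.5 * (query[start_idx:end_idx] @ key[start_idx:end_idx].transpose(-1, -2))
        hidden_states[start_idx:end_idx] = scores.softmax(dim=-1) @ value[start_idx:end_idx]

    assert hidden_states.shape == (batch_size_attention, query_tokens, head_dim)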
@@ -638,12 +656,12 @@ class SlicedAttnAddedKVProcessor:
         key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
         value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

-        batch_size_attention = query.shape[0]
+        batch_size_attention, query_tokens, _ = query.shape
         hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )

-        for i in range(hidden_states.shape[0] // self.slice_size):
+        for i in range(batch_size_attention // self.slice_size):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size

@@ -118,7 +118,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         model.enable_xformers_memory_efficient_attention()

         assert (
-            model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
+            model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
+            == "XFormersCrossAttnProcessor"
         ), "xformers is not enabled"

     @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
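The assertion now inspects the installed processor class instead of a private flag. A hedged usage sketch of the same check on a bare CrossAttention module (the import path and constructor arguments are assumed from this commit's file; `set_processor` skips the CUDA/xformers availability check that `enable_xformers_memory_efficient_attention` performs):

    # assumed import path for the version of diffusers in this commit
    from diffusers.models.cross_attention import CrossAttention, XFormersCrossAttnProcessor

    attn = CrossAttention(query_dim=320, heads=8, dim_head=40)
    attn.set_processor(XFormersCrossAttnProcessor())

    # Same check as the updated test, on attn directly instead of the UNet's attn1.
    assert attn.processor.__class__.__name__ == "XFormersCrossAttnProcessor"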