T5Attention support for cross-attention (#2654)

* fix AttnProcessor2_0

Fix use of AttnProcessor2_0 for cross attention with mask

* added scale_qk and out_bias flags

* fixed for xformers

* check if it has scale argument

* Update cross_attention.py

* check torch version

* fix sliced attn

* style

* set scale

* fix test

* fixed addedKV processor

* revert back AttnProcessor2_0

* if missing if

* fix inner_dim

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Kashif Rasul 2023-03-15 18:04:05 +01:00 committed by GitHub
parent 9d1341d69b
commit cf4227cd1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 20 deletions

View File

@ -59,6 +59,8 @@ class CrossAttention(nn.Module):
cross_attention_norm: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
processor: Optional["AttnProcessor"] = None,
):
super().__init__()
@ -68,7 +70,7 @@ class CrossAttention(nn.Module):
self.upcast_softmax = upcast_softmax
self.cross_attention_norm = cross_attention_norm
self.scale = dim_head**-0.5
self.scale = dim_head**-0.5 if scale_qk else 1.0
self.heads = heads
# for slice_size > 0 the attention score computation
@ -95,14 +97,17 @@ class CrossAttention(nn.Module):
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
# We use the AttnProcessor2_0 by default when torch2.x is used which uses
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
if processor is None:
processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else CrossAttnProcessor()
)
self.set_processor(processor)
def set_use_memory_efficient_attention_xformers(
@ -295,7 +300,9 @@ class CrossAttnProcessor:
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
@ -362,7 +369,9 @@ class LoRACrossAttnProcessor(nn.Module):
def __call__(
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@ -435,7 +444,9 @@ class XFormersCrossAttnProcessor:
self.attention_op = attention_op
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@ -454,7 +465,7 @@ class XFormersCrossAttnProcessor:
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
@ -472,7 +483,10 @@ class AttnProcessor2_0:
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, inner_dim = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
inner_dim = hidden_states.shape[-1]
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@ -496,6 +510,7 @@ class AttnProcessor2_0:
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@ -527,7 +542,9 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
def __call__(
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@ -542,7 +559,7 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = attn.batch_to_head_dim(hidden_states)
@ -559,8 +576,9 @@ class SlicedAttnProcessor:
self.slice_size = slice_size
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
@ -577,12 +595,12 @@ class SlicedAttnProcessor:
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention = query.shape[0]
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
for i in range(hidden_states.shape[0] // self.slice_size):
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
@ -638,12 +656,12 @@ class SlicedAttnAddedKVProcessor:
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
batch_size_attention = query.shape[0]
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
for i in range(hidden_states.shape[0] // self.slice_size):
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size

View File

@ -118,7 +118,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersCrossAttnProcessor"
), "xformers is not enabled"
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")