Refactor cross attention and allow mechanism to tweak cross attention function (#1639)

* first proposal

* rename

* up

* Apply suggestions from code review

* better

* up

* finish

* up

* rename

* correct versatile

* up

* up

* up

* up

* fix

* Apply suggestions from code review

* make style

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* add error message

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Patrick von Platen authored 2022-12-20 18:49:05 +01:00, committed by GitHub
parent a9190badf7
commit 4125756e88
6 changed files with 660 additions and 324 deletions
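
In short, this PR moves `CrossAttention` into a dedicated `cross_attention.py` module and routes its `forward` through a pluggable processor object (`CrossAttnProcessor`, `XFormersCrossAttnProcessor`, `SlicedAttnProcessor`, and the added-KV variants). `UNet2DConditionModel.set_attn_processor` installs a processor recursively on every attention layer, and any extra `cross_attention_kwargs` passed to the UNet are forwarded untouched to each processor call. The sketch below is modeled on the test added at the end of this diff; the tiny UNet config and the `scale` argument are illustrative assumptions, not values taken from the diff.

import torch

from diffusers import UNet2DConditionModel


class ScaledAttnProcessor:
    """Toy processor: same computation as the default CrossAttnProcessor in this diff,
    plus a runtime `scale` that arrives via `cross_attention_kwargs` (an assumed, made-up kwarg)."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        hidden_states = attn.to_out[0](hidden_states)  # linear proj
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states * scale                   # the injected tweak


# A deliberately small UNet; these config values are illustrative, not from the diff.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)
unet.set_attn_processor(ScaledAttnProcessor())  # installed recursively on every CrossAttention layer

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 32)
out = unet(
    sample,
    timestep=1,
    encoder_hidden_states=encoder_hidden_states,
    cross_attention_kwargs={"scale": 0.5},  # forwarded to every processor call
).sample
print(out.shape)  # expected: torch.Size([1, 4, 32, 32])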

View File

@@ -24,6 +24,7 @@ from ..modeling_utils import ModelMixin
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput
from ..utils.import_utils import is_xformers_available
+ from .cross_attention import CrossAttention
@dataclass
@@ -175,7 +176,14 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
- def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states=None,
+ timestep=None,
+ cross_attention_kwargs=None,
+ return_dict: bool = True,
+ ):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
@@ -213,7 +221,12 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 2. Blocks
for block in self.transformer_blocks:
- hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
# 3. Output
if self.is_input_continuous:
@@ -287,6 +300,20 @@ class AttentionBlock(nn.Module):
self._use_memory_efficient_attention_xformers = False
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.num_heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.num_heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if use_memory_efficient_attention_xformers:
if not is_xformers_available():
@@ -312,20 +339,6 @@ class AttentionBlock(nn.Module):
raise e
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
- def reshape_heads_to_batch_dim(self, tensor):
- batch_size, seq_len, dim = tensor.shape
- head_size = self.num_heads
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
- return tensor
- def reshape_batch_dim_to_heads(self, tensor):
- batch_size, seq_len, dim = tensor.shape
- head_size = self.num_heads
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
- return tensor
def forward(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
@@ -423,7 +436,8 @@ class BasicTransformerBlock(nn.Module):
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
- )  # is a self-attention
+ )
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
# 2. Cross-Attn
@@ -450,58 +464,39 @@
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim)
- def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
- if use_memory_efficient_attention_xformers:
- if not is_xformers_available():
- print("Here is how to install it")
- raise ModuleNotFoundError(
- "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
- " xformers",
- name="xformers",
- )
- elif not torch.cuda.is_available():
- raise ValueError(
- "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
- " only available for GPU "
- )
- else:
- try:
- # Make sure we can run the memory efficient attention
- _ = xformers.ops.memory_efficient_attention(
- torch.randn((1, 2, 40), device="cuda"),
- torch.randn((1, 2, 40), device="cuda"),
- torch.randn((1, 2, 40), device="cuda"),
- )
- except Exception as e:
- raise e
- self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
- if self.attn2 is not None:
- self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
- def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states=None,
+ timestep=None,
+ attention_mask=None,
+ cross_attention_kwargs=None,
+ ):
# 1. Self-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
- if self.only_cross_attention:
- hidden_states = (
- self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
- )
- else:
- hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
if self.attn2 is not None:
# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
- hidden_states = (
- self.attn2(
- norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
- )
- + hidden_states
- )
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
# 3. Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -509,229 +504,6 @@ class BasicTransformerBlock(nn.Module):
return hidden_states
class CrossAttention(nn.Module):
r"""
A cross attention layer.
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.scale = dim_head**-0.5
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self._slice_size = None
self._use_memory_efficient_attention_xformers = False
self.added_kv_proj_dim = added_kv_proj_dim
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
else:
self.group_norm = None
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
self._slice_size = slice_size
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
else:
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def _attention(self, query, key, value, attention_mask=None):
if self.upcast_attention:
query = query.float()
key = key.float()
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)
# compute attention output
hidden_states = torch.bmm(attention_probs, value)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
if self.upcast_attention:
query_slice = query_slice.float()
key_slice = key_slice.float()
attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
query_slice,
key_slice.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
if self.upcast_softmax:
attn_slice = attn_slice.float()
attn_slice = attn_slice.softmax(dim=-1)
# cast back to the original dtype
attn_slice = attn_slice.to(value.dtype)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
# TODO attention_mask
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.

View File

@@ -0,0 +1,428 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ..utils.import_utils import is_xformers_available
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class CrossAttention(nn.Module):
r"""
A cross attention layer.
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
processor: Optional["AttnProcessor"] = None,
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.scale = dim_head**-0.5
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self.added_kv_proj_dim = added_kv_proj_dim
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
else:
self.group_norm = None
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
processor = processor if processor is not None else CrossAttnProcessor()
self.set_processor(processor)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if use_memory_efficient_attention_xformers:
if self.added_kv_proj_dim is not None:
# TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
# which uses this type of cross attention ONLY because the attention mask of format
# [0, ..., -10.000, ..., 0, ...,] is not supported
raise NotImplementedError(
"Memory efficient attention with `xformers` is currently not supported when"
" `self.added_kv_proj_dim` is defined."
)
elif not is_xformers_available():
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
processor = XFormersCrossAttnProcessor()
else:
processor = CrossAttnProcessor()
self.set_processor(processor)
def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
if slice_size is not None and self.added_kv_proj_dim is not None:
processor = SlicedAttnAddedKVProcessor(slice_size)
elif slice_size is not None:
processor = SlicedAttnProcessor(slice_size)
elif self.added_kv_proj_dim is not None:
processor = CrossAttnAddedKVProcessor()
else:
processor = CrossAttnProcessor()
self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor"):
self.processor = processor
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
# The `CrossAttention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
def batch_to_head_dim(self, tensor):
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def head_to_batch_dim(self, tensor):
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def get_attention_scores(self, query, key, attention_mask=None):
dtype = query.dtype
if self.upcast_attention:
query = query.float()
key = key.float()
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
attention_probs = attention_probs.to(dtype)
return attention_probs
def prepare_attention_mask(self, attention_mask, target_length):
head_size = self.heads
if attention_mask is None:
return attention_mask
if attention_mask.shape[-1] != target_length:
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
return attention_mask
class CrossAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class CrossAttnAddedKVProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
class XFormersCrossAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class SlicedAttnProcessor:
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
)
for i in range(hidden_states.shape[0] // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class SlicedAttnAddedKVProcessor:
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
)
for i in range(hidden_states.shape[0] // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
return hidden_states
AttnProcessor = Union[
CrossAttnProcessor,
XFormersCrossAttnProcessor,
SlicedAttnProcessor,
CrossAttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
]
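
Since the new module keeps the math of the default path unchanged, swapping processors on a bare `CrossAttention` layer should be output-preserving; the snippet below is a minimal sketch of that, assuming the module is importable as `diffusers.models.cross_attention` (the path the imports elsewhere in this diff suggest) and using made-up tensor shapes.

import torch

# Assumed import path, matching "from ...models.cross_attention import ..." in this diff.
from diffusers.models.cross_attention import CrossAttention, SlicedAttnProcessor

# A standalone cross-attention layer: 4 heads of 16 channels, text context of width 32.
attn = CrossAttention(query_dim=64, cross_attention_dim=32, heads=4, dim_head=16)

hidden_states = torch.randn(2, 8, 64)           # (batch, query tokens, query_dim)
encoder_hidden_states = torch.randn(2, 6, 32)   # (batch, context tokens, cross_attention_dim)

with torch.no_grad():
    out_default = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)

    # Slicing only changes how the batch*heads axis is chunked, not the result.
    attn.set_attention_slice(2)                  # installs SlicedAttnProcessor(slice_size=2)
    out_sliced = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)

    # Processors can also be installed directly.
    attn.set_processor(SlicedAttnProcessor(slice_size=4))
    out_sliced_4 = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)

print(torch.allclose(out_default, out_sliced, atol=1e-6),
      torch.allclose(out_default, out_sliced_4, atol=1e-6))

On a CUDA machine with xformers installed, `attn.set_use_memory_efficient_attention_xformers(True)` swaps in `XFormersCrossAttnProcessor` in the same way.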

View File

@@ -15,7 +15,8 @@ import numpy as np
import torch
from torch import nn
- from .attention import AttentionBlock, CrossAttention, DualTransformer2DModel, Transformer2DModel
+ from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
+ from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
@@ -481,11 +482,16 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
- # TODO(Patrick, William) - attention_mask is currently not used. Implement once used
+ def forward(
+ self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+ ):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
- hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -544,6 +550,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
+ processor=CrossAttnAddedKVProcessor(),
)
)
resnets.append(
@@ -564,19 +571,19 @@
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ def forward(
+ self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+ ):
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# attn
- residual = hidden_states
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
hidden_states = attn(
hidden_states,
- encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
+ encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
+ **cross_attention_kwargs,
)
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
- hidden_states = hidden_states + residual
# resnet
hidden_states = resnet(hidden_states, temb)
@@ -750,7 +757,9 @@ class CrossAttnDownBlock2D(nn.Module):
self.gradient_checkpointing = False
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ def forward(
+ self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+ ):
# TODO(Patrick, William) - attention mask is not used
output_states = ()
@@ -771,10 +780,15 @@
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
+ cross_attention_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
output_states += (hidden_states,)
@@ -1310,6 +1324,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
+ processor=CrossAttnAddedKVProcessor(),
)
)
self.attentions = nn.ModuleList(attentions)
@@ -1338,23 +1353,23 @@
self.gradient_checkpointing = False
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ def forward(
+ self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+ ):
output_states = ()
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
for resnet, attn in zip(self.resnets, self.attentions):
# resnet
hidden_states = resnet(hidden_states, temb)
# attn
- residual = hidden_states
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
hidden_states = attn(
hidden_states,
- encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
+ encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
+ **cross_attention_kwargs,
)
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
- hidden_states = hidden_states + residual
output_states += (hidden_states,)
@@ -1531,6 +1546,7 @@ class CrossAttnUpBlock2D(nn.Module):
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
+ cross_attention_kwargs=None,
upsample_size=None,
attention_mask=None,
):
@@ -1557,10 +1573,15 @@
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
+ cross_attention_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -2113,6 +2134,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
+ processor=CrossAttnAddedKVProcessor(),
)
)
self.attentions = nn.ModuleList(attentions)
@@ -2149,7 +2171,9 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
encoder_hidden_states=None,
upsample_size=None,
attention_mask=None,
+ cross_attention_kwargs=None,
):
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
for resnet, attn in zip(self.resnets, self.attentions):
# resnet
# pop res hidden states
@@ -2160,15 +2184,12 @@
hidden_states = resnet(hidden_states, temb)
# attn
- residual = hidden_states
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
hidden_states = attn(
hidden_states,
- encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
+ encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
+ **cross_attention_kwargs,
)
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
- hidden_states = hidden_states + residual
if self.upsamplers is not None:
for upsampler in self.upsamplers:

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
- from typing import List, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -21,6 +21,7 @@ import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput, logging
+ from .cross_attention import AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .unet_2d_blocks import (
CrossAttnDownBlock2D,
@@ -265,6 +266,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+ def set_attn_processor(self, processor: AttnProcessor):
+ # set recursively
+ def fn_recursive_attn_processor(module: torch.nn.Module):
+ if hasattr(module, "set_processor"):
+ module.set_processor(processor)
+ for child in module.children():
+ fn_recursive_attn_processor(child)
+ for module in self.children():
+ fn_recursive_attn_processor(module)
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
@@ -341,6 +354,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
@@ -426,6 +440,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
@@ -434,7 +449,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# 4. mid
sample = self.mid_block(
- sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
)
# 5. up
@@ -455,6 +474,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
)

View File

@@ -1,4 +1,4 @@
- from typing import List, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -7,8 +7,8 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...modeling_utils import ModelMixin
from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel
+ from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor
from ...models.embeddings import TimestepEmbedding, Timesteps
- from ...models.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn as UNetMidBlockFlatSimpleCrossAttn
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...utils import logging
@@ -351,6 +351,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU()
self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+ def set_attn_processor(self, processor: AttnProcessor):
+ # set recursively
+ def fn_recursive_attn_processor(module: torch.nn.Module):
+ if hasattr(module, "set_processor"):
+ module.set_processor(processor)
+ for child in module.children():
+ fn_recursive_attn_processor(child)
+ for module in self.children():
+ fn_recursive_attn_processor(module)
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
@@ -427,6 +439,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
@@ -512,6 +525,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
@@ -520,7 +534,11 @@
# 4. mid
sample = self.mid_block(
- sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
)
# 5. up
@@ -541,6 +559,7 @@
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
)
@@ -840,7 +859,9 @@ class CrossAttnDownBlockFlat(nn.Module):
self.gradient_checkpointing = False
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ def forward(
+ self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+ ):
# TODO(Patrick, William) - attention mask is not used
output_states = ()
@@ -861,10 +882,15 @@
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
+ cross_attention_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
output_states += (hidden_states,)
@@ -1042,6 +1068,7 @@ class CrossAttnUpBlockFlat(nn.Module):
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
+ cross_attention_kwargs=None,
upsample_size=None,
attention_mask=None,
):
@@ -1068,10 +1095,15 @@
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
+ cross_attention_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
- hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -1166,18 +1198,23 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
- # TODO(Patrick, William) - attention_mask is currently not used. Implement once used
+ def forward(
+ self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+ ):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
- hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
hidden_states = resnet(hidden_states, temb)
return hidden_states
- # Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
+ # Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
- class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module):
+ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
@@ -1230,6 +1267,7 @@ class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module):
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
+ processor=CrossAttnAddedKVProcessor(),
)
)
resnets.append(
@@ -1250,19 +1288,19 @@
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ def forward(
+ self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
+ ):
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# attn
- residual = hidden_states
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
hidden_states = attn(
hidden_states,
- encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
+ encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
+ **cross_attention_kwargs,
)
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
- hidden_states = hidden_states + residual
# resnet
hidden_states = resnet(hidden_states, temb)

View File

@@ -391,6 +391,63 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
for module in model.children():
check_slicable_dim_attr(module)
def test_special_attn_proc(self):
class AttnEasyProc(torch.nn.Module):
def __init__(self, num):
super().__init__()
self.weight = torch.nn.Parameter(torch.tensor(num))
self.is_run = False
self.number = 0
self.counter = 0
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states += self.weight
self.is_run = True
self.counter += 1
self.number = number
return hidden_states
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
processor = AttnEasyProc(5.0)
model.set_attn_processor(processor)
model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample
assert processor.counter == 12
assert processor.is_run
assert processor.number == 123
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet2DModel