Refactor cross attention and allow mechanism to tweak cross attention function (#1639)
* first proposal * rename * up * Apply suggestions from code review * better * up * finish * up * rename * correct versatile * up * up * up * up * fix * Apply suggestions from code review * make style * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * add error message Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
a9190badf7
commit
4125756e88
|
@ -24,6 +24,7 @@ from ..modeling_utils import ModelMixin
|
|||
from ..models.embeddings import ImagePositionalEmbeddings
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
from .cross_attention import CrossAttention
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -175,7 +176,14 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
|||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
timestep=None,
|
||||
cross_attention_kwargs=None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
|
@ -213,7 +221,12 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
|||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
|
@ -287,6 +300,20 @@ class AttentionBlock(nn.Module):
|
|||
|
||||
self._use_memory_efficient_attention_xformers = False
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.num_heads
|
||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
return tensor
|
||||
|
||||
def reshape_batch_dim_to_heads(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.num_heads
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if not is_xformers_available():
|
||||
|
@ -312,20 +339,6 @@ class AttentionBlock(nn.Module):
|
|||
raise e
|
||||
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.num_heads
|
||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
return tensor
|
||||
|
||||
def reshape_batch_dim_to_heads(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.num_heads
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def forward(self, hidden_states):
|
||||
residual = hidden_states
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
|
@ -423,7 +436,8 @@ class BasicTransformerBlock(nn.Module):
|
|||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
) # is a self-attention
|
||||
)
|
||||
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
||||
|
||||
# 2. Cross-Attn
|
||||
|
@ -450,58 +464,39 @@ class BasicTransformerBlock(nn.Module):
|
|||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if not is_xformers_available():
|
||||
print("Here is how to install it")
|
||||
raise ModuleNotFoundError(
|
||||
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
||||
" xformers",
|
||||
name="xformers",
|
||||
)
|
||||
elif not torch.cuda.is_available():
|
||||
raise ValueError(
|
||||
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
||||
" only available for GPU "
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Make sure we can run the memory efficient attention
|
||||
_ = xformers.ops.memory_efficient_attention(
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
if self.attn2 is not None:
|
||||
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
timestep=None,
|
||||
attention_mask=None,
|
||||
cross_attention_kwargs=None,
|
||||
):
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
||||
)
|
||||
|
||||
if self.only_cross_attention:
|
||||
hidden_states = (
|
||||
self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
if self.attn2 is not None:
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
hidden_states = (
|
||||
self.attn2(
|
||||
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
+ hidden_states
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
|
@ -509,229 +504,6 @@ class BasicTransformerBlock(nn.Module):
|
|||
return hidden_states
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
r"""
|
||||
A cross attention layer.
|
||||
|
||||
Parameters:
|
||||
query_dim (`int`): The number of channels in the query.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
||||
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
||||
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
bias (`bool`, *optional*, defaults to False):
|
||||
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias=False,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
norm_num_groups: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
|
||||
self.heads = heads
|
||||
# for slice_size > 0 the attention score computation
|
||||
# is split across the batch axis to save memory
|
||||
# You can set slice_size with `set_attention_slice`
|
||||
self.sliceable_head_dim = heads
|
||||
self._slice_size = None
|
||||
self._use_memory_efficient_attention_xformers = False
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
|
||||
if norm_num_groups is not None:
|
||||
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
||||
else:
|
||||
self.group_norm = None
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
return tensor
|
||||
|
||||
def reshape_batch_dim_to_heads(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
||||
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
||||
|
||||
self._slice_size = slice_size
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states
|
||||
|
||||
if self.group_norm is not None:
|
||||
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = self.reshape_heads_to_batch_dim(query)
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
||||
|
||||
key = self.reshape_heads_to_batch_dim(key)
|
||||
value = self.reshape_heads_to_batch_dim(value)
|
||||
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
||||
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
||||
|
||||
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
||||
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
||||
else:
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = self.to_k(encoder_hidden_states)
|
||||
value = self.to_v(encoder_hidden_states)
|
||||
|
||||
key = self.reshape_heads_to_batch_dim(key)
|
||||
value = self.reshape_heads_to_batch_dim(value)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.shape[-1] != query.shape[1]:
|
||||
target_length = query.shape[1]
|
||||
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
if self._use_memory_efficient_attention_xformers:
|
||||
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
||||
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
else:
|
||||
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
||||
hidden_states = self._attention(query, key, value, attention_mask)
|
||||
else:
|
||||
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
||||
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
|
||||
# dropout
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _attention(self, query, key, value, attention_mask=None):
|
||||
if self.upcast_attention:
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
if self.upcast_softmax:
|
||||
attention_scores = attention_scores.float()
|
||||
|
||||
attention_probs = attention_scores.softmax(dim=-1)
|
||||
|
||||
# cast back to the original dtype
|
||||
attention_probs = attention_probs.to(value.dtype)
|
||||
|
||||
# compute attention output
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
|
||||
# reshape hidden_states
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
|
||||
batch_size_attention = query.shape[0]
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
||||
)
|
||||
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
||||
for i in range(hidden_states.shape[0] // slice_size):
|
||||
start_idx = i * slice_size
|
||||
end_idx = (i + 1) * slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
|
||||
if self.upcast_attention:
|
||||
query_slice = query_slice.float()
|
||||
key_slice = key_slice.float()
|
||||
|
||||
attn_slice = torch.baddbmm(
|
||||
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
|
||||
query_slice,
|
||||
key_slice.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
|
||||
|
||||
if self.upcast_softmax:
|
||||
attn_slice = attn_slice.float()
|
||||
|
||||
attn_slice = attn_slice.softmax(dim=-1)
|
||||
|
||||
# cast back to the original dtype
|
||||
attn_slice = attn_slice.to(value.dtype)
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
|
||||
# reshape hidden_states
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
||||
# TODO attention_mask
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
|
|
@ -0,0 +1,428 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
else:
|
||||
xformers = None
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
r"""
|
||||
A cross attention layer.
|
||||
|
||||
Parameters:
|
||||
query_dim (`int`): The number of channels in the query.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
||||
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
||||
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
bias (`bool`, *optional*, defaults to False):
|
||||
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias=False,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
norm_num_groups: Optional[int] = None,
|
||||
processor: Optional["AttnProcessor"] = None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
|
||||
self.heads = heads
|
||||
# for slice_size > 0 the attention score computation
|
||||
# is split across the batch axis to save memory
|
||||
# You can set slice_size with `set_attention_slice`
|
||||
self.sliceable_head_dim = heads
|
||||
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
|
||||
if norm_num_groups is not None:
|
||||
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
||||
else:
|
||||
self.group_norm = None
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
# set attention processor
|
||||
processor = processor if processor is not None else CrossAttnProcessor()
|
||||
self.set_processor(processor)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if self.added_kv_proj_dim is not None:
|
||||
# TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
|
||||
# which uses this type of cross attention ONLY because the attention mask of format
|
||||
# [0, ..., -10.000, ..., 0, ...,] is not supported
|
||||
raise NotImplementedError(
|
||||
"Memory efficient attention with `xformers` is currently not supported when"
|
||||
" `self.added_kv_proj_dim` is defined."
|
||||
)
|
||||
elif not is_xformers_available():
|
||||
raise ModuleNotFoundError(
|
||||
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
||||
" xformers",
|
||||
name="xformers",
|
||||
)
|
||||
elif not torch.cuda.is_available():
|
||||
raise ValueError(
|
||||
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
||||
" only available for GPU "
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Make sure we can run the memory efficient attention
|
||||
_ = xformers.ops.memory_efficient_attention(
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
processor = XFormersCrossAttnProcessor()
|
||||
else:
|
||||
processor = CrossAttnProcessor()
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
||||
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
||||
|
||||
if slice_size is not None and self.added_kv_proj_dim is not None:
|
||||
processor = SlicedAttnAddedKVProcessor(slice_size)
|
||||
elif slice_size is not None:
|
||||
processor = SlicedAttnProcessor(slice_size)
|
||||
elif self.added_kv_proj_dim is not None:
|
||||
processor = CrossAttnAddedKVProcessor()
|
||||
else:
|
||||
processor = CrossAttnProcessor()
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
def set_processor(self, processor: "AttnProcessor"):
|
||||
self.processor = processor
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
|
||||
# The `CrossAttention` class can call different attention processors / attention functions
|
||||
# here we simply pass along all tensors to the selected processor class
|
||||
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
def batch_to_head_dim(self, tensor):
|
||||
head_size = self.heads
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def head_to_batch_dim(self, tensor):
|
||||
head_size = self.heads
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
return tensor
|
||||
|
||||
def get_attention_scores(self, query, key, attention_mask=None):
|
||||
dtype = query.dtype
|
||||
if self.upcast_attention:
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
if self.upcast_softmax:
|
||||
attention_scores = attention_scores.float()
|
||||
|
||||
attention_probs = attention_scores.softmax(dim=-1)
|
||||
attention_probs = attention_probs.to(dtype)
|
||||
|
||||
return attention_probs
|
||||
|
||||
def prepare_attention_mask(self, attention_mask, target_length):
|
||||
head_size = self.heads
|
||||
if attention_mask is None:
|
||||
return attention_mask
|
||||
|
||||
if attention_mask.shape[-1] != target_length:
|
||||
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
||||
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
||||
return attention_mask
|
||||
|
||||
|
||||
class CrossAttnProcessor:
|
||||
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CrossAttnAddedKVProcessor:
|
||||
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
||||
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
||||
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
||||
|
||||
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
||||
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class XFormersCrossAttnProcessor:
|
||||
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query).contiguous()
|
||||
key = attn.head_to_batch_dim(key).contiguous()
|
||||
value = attn.head_to_batch_dim(value).contiguous()
|
||||
|
||||
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SlicedAttnProcessor:
|
||||
def __init__(self, slice_size):
|
||||
self.slice_size = slice_size
|
||||
|
||||
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
batch_size_attention = query.shape[0]
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||
)
|
||||
|
||||
for i in range(hidden_states.shape[0] // self.slice_size):
|
||||
start_idx = i * self.slice_size
|
||||
end_idx = (i + 1) * self.slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SlicedAttnAddedKVProcessor:
|
||||
def __init__(self, slice_size):
|
||||
self.slice_size = slice_size
|
||||
|
||||
def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
||||
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
||||
|
||||
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
||||
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
||||
|
||||
batch_size_attention = query.shape[0]
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||
)
|
||||
|
||||
for i in range(hidden_states.shape[0] // self.slice_size):
|
||||
start_idx = i * self.slice_size
|
||||
end_idx = (i + 1) * self.slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
AttnProcessor = Union[
|
||||
CrossAttnProcessor,
|
||||
XFormersCrossAttnProcessor,
|
||||
SlicedAttnProcessor,
|
||||
CrossAttnAddedKVProcessor,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
]
|
|
@ -15,7 +15,8 @@ import numpy as np
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import AttentionBlock, CrossAttention, DualTransformer2DModel, Transformer2DModel
|
||||
from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
|
||||
from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
|
||||
|
||||
|
@ -481,11 +482,16 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
|||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
# TODO(Patrick, William) - attention_mask is currently not used. Implement once used
|
||||
def forward(
|
||||
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
||||
):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
@ -544,6 +550,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
|||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
processor=CrossAttnAddedKVProcessor(),
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
|
@ -564,19 +571,19 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
|||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
def forward(
|
||||
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
||||
):
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
@ -750,7 +757,9 @@ class CrossAttnDownBlock2D(nn.Module):
|
|||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
def forward(
|
||||
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
||||
):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
output_states = ()
|
||||
|
||||
|
@ -771,10 +780,15 @@ class CrossAttnDownBlock2D(nn.Module):
|
|||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
cross_attention_kwargs,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
|
@ -1310,6 +1324,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
|||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
processor=CrossAttnAddedKVProcessor(),
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
|
@ -1338,23 +1353,23 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
|||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
def forward(
|
||||
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
||||
):
|
||||
output_states = ()
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
|
@ -1531,6 +1546,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
|||
res_hidden_states_tuple,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
cross_attention_kwargs=None,
|
||||
upsample_size=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
|
@ -1557,10 +1573,15 @@ class CrossAttnUpBlock2D(nn.Module):
|
|||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
cross_attention_kwargs,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
|
@ -2113,6 +2134,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
|||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
processor=CrossAttnAddedKVProcessor(),
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
|
@ -2149,7 +2171,9 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
|||
encoder_hidden_states=None,
|
||||
upsample_size=None,
|
||||
attention_mask=None,
|
||||
cross_attention_kwargs=None,
|
||||
):
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# resnet
|
||||
# pop res hidden states
|
||||
|
@ -2160,15 +2184,12 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
|||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -21,6 +21,7 @@ import torch.utils.checkpoint
|
|||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .cross_attention import AttnProcessor
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
|
@ -265,6 +266,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
|||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def set_attn_processor(self, processor: AttnProcessor):
|
||||
# set recursively
|
||||
def fn_recursive_attn_processor(module: torch.nn.Module):
|
||||
if hasattr(module, "set_processor"):
|
||||
module.set_processor(processor)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_attn_processor(child)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_attn_processor(module)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
@ -341,6 +354,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
|||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
|
@ -426,6 +440,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
|||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
@ -434,7 +449,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
|||
|
||||
# 4. mid
|
||||
sample = self.mid_block(
|
||||
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
# 5. up
|
||||
|
@ -455,6 +474,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
|||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -7,8 +7,8 @@ import torch.nn as nn
|
|||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...modeling_utils import ModelMixin
|
||||
from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel
|
||||
from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor
|
||||
from ...models.embeddings import TimestepEmbedding, Timesteps
|
||||
from ...models.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn as UNetMidBlockFlatSimpleCrossAttn
|
||||
from ...models.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...utils import logging
|
||||
|
||||
|
@ -351,6 +351,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def set_attn_processor(self, processor: AttnProcessor):
|
||||
# set recursively
|
||||
def fn_recursive_attn_processor(module: torch.nn.Module):
|
||||
if hasattr(module, "set_processor"):
|
||||
module.set_processor(processor)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_attn_processor(child)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_attn_processor(module)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
@ -427,6 +439,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
|
@ -512,6 +525,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
@ -520,7 +534,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|||
|
||||
# 4. mid
|
||||
sample = self.mid_block(
|
||||
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
# 5. up
|
||||
|
@ -541,6 +559,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
@ -840,7 +859,9 @@ class CrossAttnDownBlockFlat(nn.Module):
|
|||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
def forward(
|
||||
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
||||
):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
output_states = ()
|
||||
|
||||
|
@ -861,10 +882,15 @@ class CrossAttnDownBlockFlat(nn.Module):
|
|||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
cross_attention_kwargs,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
|
@ -1042,6 +1068,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
|||
res_hidden_states_tuple,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
cross_attention_kwargs=None,
|
||||
upsample_size=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
|
@ -1068,10 +1095,15 @@ class CrossAttnUpBlockFlat(nn.Module):
|
|||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
cross_attention_kwargs,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
|
@ -1166,18 +1198,23 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
|||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
# TODO(Patrick, William) - attention_mask is currently not used. Implement once used
|
||||
def forward(
|
||||
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
||||
):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module):
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
|
@ -1230,6 +1267,7 @@ class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module):
|
|||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
processor=CrossAttnAddedKVProcessor(),
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
|
@ -1250,19 +1288,19 @@ class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module):
|
|||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
def forward(
|
||||
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
||||
):
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
|
|
@ -391,6 +391,63 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
for module in model.children():
|
||||
check_slicable_dim_attr(module)
|
||||
|
||||
def test_special_attn_proc(self):
|
||||
class AttnEasyProc(torch.nn.Module):
|
||||
def __init__(self, num):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.tensor(num))
|
||||
self.is_run = False
|
||||
self.number = 0
|
||||
self.counter = 0
|
||||
|
||||
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
hidden_states += self.weight
|
||||
|
||||
self.is_run = True
|
||||
self.counter += 1
|
||||
self.number = number
|
||||
|
||||
return hidden_states
|
||||
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
processor = AttnEasyProc(5.0)
|
||||
|
||||
model.set_attn_processor(processor)
|
||||
model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample
|
||||
|
||||
assert processor.counter == 12
|
||||
assert processor.is_run
|
||||
assert processor.number == 123
|
||||
|
||||
|
||||
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
|
|
Loading…
Reference in New Issue