Merge pull request #57 from huggingface/big_clean_up
[Clean up] Clean up unused code
This commit is contained in:
commit
abedfb08f1
|
@ -34,48 +34,6 @@ def Normalize(in_channels):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
|
||||||
|
|
||||||
# class ResnetBlock(nn.Module):
|
|
||||||
# def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
|
||||||
# super().__init__()
|
|
||||||
# self.in_channels = in_channels
|
|
||||||
# out_channels = in_channels if out_channels is None else out_channels
|
|
||||||
# self.out_channels = out_channels
|
|
||||||
# self.use_conv_shortcut = conv_shortcut
|
|
||||||
#
|
|
||||||
# self.norm1 = Normalize(in_channels)
|
|
||||||
# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
|
||||||
# self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
|
||||||
# self.norm2 = Normalize(out_channels)
|
|
||||||
# self.dropout = torch.nn.Dropout(dropout)
|
|
||||||
# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
|
||||||
# if self.in_channels != self.out_channels:
|
|
||||||
# if self.use_conv_shortcut:
|
|
||||||
# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
|
||||||
# else:
|
|
||||||
# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
|
||||||
#
|
|
||||||
# def forward(self, x, temb):
|
|
||||||
# h = x
|
|
||||||
# h = self.norm1(h)
|
|
||||||
# h = nonlinearity(h)
|
|
||||||
# h = self.conv1(h)
|
|
||||||
#
|
|
||||||
# h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
|
||||||
#
|
|
||||||
# h = self.norm2(h)
|
|
||||||
# h = nonlinearity(h)
|
|
||||||
# h = self.dropout(h)
|
|
||||||
# h = self.conv2(h)
|
|
||||||
#
|
|
||||||
# if self.in_channels != self.out_channels:
|
|
||||||
# if self.use_conv_shortcut:
|
|
||||||
# x = self.conv_shortcut(x)
|
|
||||||
# else:
|
|
||||||
# x = self.nin_shortcut(x)
|
|
||||||
#
|
|
||||||
# return x + h
|
|
||||||
|
|
||||||
|
|
||||||
class UNetModel(ModelMixin, ConfigMixin):
|
class UNetModel(ModelMixin, ConfigMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -29,19 +29,6 @@ def convert_module_to_f32(l):
|
||||||
l.bias.data = l.bias.data.float()
|
l.bias.data = l.bias.data.float()
|
||||||
|
|
||||||
|
|
||||||
def avg_pool_nd(dims, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Create a 1D, 2D, or 3D average pooling module.
|
|
||||||
"""
|
|
||||||
if dims == 1:
|
|
||||||
return nn.AvgPool1d(*args, **kwargs)
|
|
||||||
elif dims == 2:
|
|
||||||
return nn.AvgPool2d(*args, **kwargs)
|
|
||||||
elif dims == 3:
|
|
||||||
return nn.AvgPool3d(*args, **kwargs)
|
|
||||||
raise ValueError(f"unsupported dimensions: {dims}")
|
|
||||||
|
|
||||||
|
|
||||||
def conv_nd(dims, *args, **kwargs):
|
def conv_nd(dims, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create a 1D, 2D, or 3D convolution module.
|
Create a 1D, 2D, or 3D convolution module.
|
||||||
|
|
|
@ -78,182 +78,6 @@ def Normalize(in_channels):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
|
||||||
|
|
||||||
# class LinearAttention(nn.Module):
|
|
||||||
# def __init__(self, dim, heads=4, dim_head=32):
|
|
||||||
# super().__init__()
|
|
||||||
# self.heads = heads
|
|
||||||
# hidden_dim = dim_head * heads
|
|
||||||
# self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
|
||||||
# self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
|
||||||
#
|
|
||||||
# def forward(self, x):
|
|
||||||
# b, c, h, w = x.shape
|
|
||||||
# qkv = self.to_qkv(x)
|
|
||||||
# q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
# k = k.softmax(dim=-1)
|
|
||||||
# context = torch.einsum("bhdn,bhen->bhde", k, v)
|
|
||||||
# out = torch.einsum("bhde,bhdn->bhen", context, q)
|
|
||||||
# out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
|
|
||||||
# return self.to_out(out)
|
|
||||||
#
|
|
||||||
|
|
||||||
# class SpatialSelfAttention(nn.Module):
|
|
||||||
# def __init__(self, in_channels):
|
|
||||||
# super().__init__()
|
|
||||||
# self.in_channels = in_channels
|
|
||||||
#
|
|
||||||
# self.norm = Normalize(in_channels)
|
|
||||||
# self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
|
||||||
# self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
|
||||||
# self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
|
||||||
# self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
|
||||||
#
|
|
||||||
# def forward(self, x):
|
|
||||||
# h_ = x
|
|
||||||
# h_ = self.norm(h_)
|
|
||||||
# q = self.q(h_)
|
|
||||||
# k = self.k(h_)
|
|
||||||
# v = self.v(h_)
|
|
||||||
#
|
|
||||||
# compute attention
|
|
||||||
# b, c, h, w = q.shape
|
|
||||||
# q = rearrange(q, "b c h w -> b (h w) c")
|
|
||||||
# k = rearrange(k, "b c h w -> b c (h w)")
|
|
||||||
# w_ = torch.einsum("bij,bjk->bik", q, k)
|
|
||||||
#
|
|
||||||
# w_ = w_ * (int(c) ** (-0.5))
|
|
||||||
# w_ = torch.nn.functional.softmax(w_, dim=2)
|
|
||||||
#
|
|
||||||
# attend to values
|
|
||||||
# v = rearrange(v, "b c h w -> b c (h w)")
|
|
||||||
# w_ = rearrange(w_, "b i j -> b j i")
|
|
||||||
# h_ = torch.einsum("bij,bjk->bik", v, w_)
|
|
||||||
# h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
|
||||||
# h_ = self.proj_out(h_)
|
|
||||||
#
|
|
||||||
# return x + h_
|
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = dim_head * heads
|
|
||||||
context_dim = default(context_dim, query_dim)
|
|
||||||
|
|
||||||
self.scale = dim_head**-0.5
|
|
||||||
self.heads = heads
|
|
||||||
|
|
||||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
|
||||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
|
||||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
|
||||||
|
|
||||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
|
||||||
|
|
||||||
def reshape_heads_to_batch_dim(self, tensor):
|
|
||||||
batch_size, seq_len, dim = tensor.shape
|
|
||||||
head_size = self.heads
|
|
||||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
|
||||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
def reshape_batch_dim_to_heads(self, tensor):
|
|
||||||
batch_size, seq_len, dim = tensor.shape
|
|
||||||
head_size = self.heads
|
|
||||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
|
||||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
def forward(self, x, context=None, mask=None):
|
|
||||||
batch_size, sequence_length, dim = x.shape
|
|
||||||
|
|
||||||
h = self.heads
|
|
||||||
|
|
||||||
q = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
k = self.to_k(context)
|
|
||||||
v = self.to_v(context)
|
|
||||||
|
|
||||||
q = self.reshape_heads_to_batch_dim(q)
|
|
||||||
k = self.reshape_heads_to_batch_dim(k)
|
|
||||||
v = self.reshape_heads_to_batch_dim(v)
|
|
||||||
|
|
||||||
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
|
||||||
|
|
||||||
if exists(mask):
|
|
||||||
mask = mask.reshape(batch_size, -1)
|
|
||||||
max_neg_value = -torch.finfo(sim.dtype).max
|
|
||||||
mask = mask[:, None, :].repeat(h, 1, 1)
|
|
||||||
sim.masked_fill_(~mask, max_neg_value)
|
|
||||||
|
|
||||||
# attention, what we cannot get enough of
|
|
||||||
attn = sim.softmax(dim=-1)
|
|
||||||
|
|
||||||
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
|
||||||
out = self.reshape_batch_dim_to_heads(out)
|
|
||||||
return self.to_out(out)
|
|
||||||
|
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
|
||||||
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
|
|
||||||
super().__init__()
|
|
||||||
self.attn1 = CrossAttention(
|
|
||||||
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
|
||||||
) # is a self-attention
|
|
||||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
|
||||||
self.attn2 = CrossAttention(
|
|
||||||
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
|
||||||
) # is self-attn if context is none
|
|
||||||
self.norm1 = nn.LayerNorm(dim)
|
|
||||||
self.norm2 = nn.LayerNorm(dim)
|
|
||||||
self.norm3 = nn.LayerNorm(dim)
|
|
||||||
self.checkpoint = checkpoint
|
|
||||||
|
|
||||||
def forward(self, x, context=None):
|
|
||||||
x = self.attn1(self.norm1(x)) + x
|
|
||||||
x = self.attn2(self.norm2(x), context=context) + x
|
|
||||||
x = self.ff(self.norm3(x)) + x
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SpatialTransformer(nn.Module):
|
|
||||||
"""
|
|
||||||
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
|
|
||||||
standard transformer action. Finally, reshape to image
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
inner_dim = n_heads * d_head
|
|
||||||
self.norm = Normalize(in_channels)
|
|
||||||
|
|
||||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
|
||||||
|
|
||||||
self.transformer_blocks = nn.ModuleList(
|
|
||||||
[
|
|
||||||
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
|
||||||
for d in range(depth)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
|
||||||
|
|
||||||
def forward(self, x, context=None):
|
|
||||||
# note: if no context is given, cross-attention defaults to self-attention
|
|
||||||
b, c, h, w = x.shape
|
|
||||||
x_in = x
|
|
||||||
x = self.norm(x)
|
|
||||||
x = self.proj_in(x)
|
|
||||||
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
|
||||||
for block in self.transformer_blocks:
|
|
||||||
x = block(x, context=context)
|
|
||||||
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
|
||||||
x = self.proj_out(x)
|
|
||||||
return x + x_in
|
|
||||||
|
|
||||||
|
|
||||||
def convert_module_to_f16(l):
|
def convert_module_to_f16(l):
|
||||||
"""
|
"""
|
||||||
Convert primitive modules to float16.
|
Convert primitive modules to float16.
|
||||||
|
@ -274,19 +98,6 @@ def convert_module_to_f32(l):
|
||||||
l.bias.data = l.bias.data.float()
|
l.bias.data = l.bias.data.float()
|
||||||
|
|
||||||
|
|
||||||
def avg_pool_nd(dims, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Create a 1D, 2D, or 3D average pooling module.
|
|
||||||
"""
|
|
||||||
if dims == 1:
|
|
||||||
return nn.AvgPool1d(*args, **kwargs)
|
|
||||||
elif dims == 2:
|
|
||||||
return nn.AvgPool2d(*args, **kwargs)
|
|
||||||
elif dims == 3:
|
|
||||||
return nn.AvgPool3d(*args, **kwargs)
|
|
||||||
raise ValueError(f"unsupported dimensions: {dims}")
|
|
||||||
|
|
||||||
|
|
||||||
def conv_nd(dims, *args, **kwargs):
|
def conv_nd(dims, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create a 1D, 2D, or 3D convolution module.
|
Create a 1D, 2D, or 3D convolution module.
|
||||||
|
@ -330,36 +141,6 @@ def normalization(channels, swish=0.0):
|
||||||
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
|
return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
|
||||||
|
|
||||||
|
|
||||||
class AttentionPool2d(nn.Module):
|
|
||||||
"""
|
|
||||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
spacial_dim: int,
|
|
||||||
embed_dim: int,
|
|
||||||
num_heads_channels: int,
|
|
||||||
output_dim: int = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
|
|
||||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
|
||||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
|
||||||
self.num_heads = embed_dim // num_heads_channels
|
|
||||||
self.attention = QKVAttention(self.num_heads)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
b, c, *_spatial = x.shape
|
|
||||||
x = x.reshape(b, c, -1) # NC(HW)
|
|
||||||
x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
|
||||||
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
|
||||||
x = self.qkv_proj(x)
|
|
||||||
x = self.attention(x)
|
|
||||||
x = self.c_proj(x)
|
|
||||||
return x[:, :, 0]
|
|
||||||
|
|
||||||
|
|
||||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||||
"""
|
"""
|
||||||
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
A sequential module that passes timestep embeddings to the children that support it as an extra input.
|
||||||
|
@ -376,39 +157,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class QKVAttention(nn.Module):
|
|
||||||
"""
|
|
||||||
A module which performs QKV attention and splits in a different order.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, n_heads):
|
|
||||||
super().__init__()
|
|
||||||
self.n_heads = n_heads
|
|
||||||
|
|
||||||
def forward(self, qkv):
|
|
||||||
"""
|
|
||||||
Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x
|
|
||||||
T] tensor after attention.
|
|
||||||
"""
|
|
||||||
bs, width, length = qkv.shape
|
|
||||||
assert width % (3 * self.n_heads) == 0
|
|
||||||
ch = width // (3 * self.n_heads)
|
|
||||||
q, k, v = qkv.chunk(3, dim=1)
|
|
||||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
|
||||||
weight = torch.einsum(
|
|
||||||
"bct,bcs->bts",
|
|
||||||
(q * scale).view(bs * self.n_heads, ch, length),
|
|
||||||
(k * scale).view(bs * self.n_heads, ch, length),
|
|
||||||
) # More stable with f16 than dividing afterwards
|
|
||||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
|
||||||
a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
|
||||||
return a.reshape(bs, -1, length)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def count_flops(model, _x, y):
|
|
||||||
return count_flops_attn(model, _x, y)
|
|
||||||
|
|
||||||
|
|
||||||
def count_flops_attn(model, _x, y):
|
def count_flops_attn(model, _x, y):
|
||||||
"""
|
"""
|
||||||
A counter for the `thop` package to count the operations in an attention operation. Meant to be used like:
|
A counter for the `thop` package to count the operations in an attention operation. Meant to be used like:
|
||||||
|
@ -602,21 +350,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||||
out_ch = ch
|
out_ch = ch
|
||||||
self.input_blocks.append(
|
self.input_blocks.append(
|
||||||
TimestepEmbedSequential(
|
TimestepEmbedSequential(
|
||||||
# ResBlock(
|
Downsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op")
|
||||||
# ch,
|
|
||||||
# time_embed_dim,
|
|
||||||
# dropout,
|
|
||||||
# out_channels=out_ch,
|
|
||||||
# dims=dims,
|
|
||||||
# use_checkpoint=use_checkpoint,
|
|
||||||
# use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
# down=True,
|
|
||||||
# )
|
|
||||||
None
|
|
||||||
if resblock_updown
|
|
||||||
else Downsample(
|
|
||||||
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
ch = out_ch
|
ch = out_ch
|
||||||
|
@ -703,21 +437,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||||
)
|
)
|
||||||
if level and i == num_res_blocks:
|
if level and i == num_res_blocks:
|
||||||
out_ch = ch
|
out_ch = ch
|
||||||
layers.append(
|
layers.append(Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch))
|
||||||
# ResBlock(
|
|
||||||
# ch,
|
|
||||||
# time_embed_dim,
|
|
||||||
# dropout,
|
|
||||||
# out_channels=out_ch,
|
|
||||||
# dims=dims,
|
|
||||||
# use_checkpoint=use_checkpoint,
|
|
||||||
# use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
# up=True,
|
|
||||||
# )
|
|
||||||
None
|
|
||||||
if resblock_updown
|
|
||||||
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
|
|
||||||
)
|
|
||||||
ds //= 2
|
ds //= 2
|
||||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
@ -784,215 +504,119 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
||||||
return self.out(h)
|
return self.out(h)
|
||||||
|
|
||||||
|
|
||||||
class EncoderUNetModel(nn.Module):
|
class SpatialTransformer(nn.Module):
|
||||||
"""
|
"""
|
||||||
The half UNet model with attention and timestep embedding. For usage, see UNet.
|
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
|
||||||
|
standard transformer action. Finally, reshape to image
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
|
||||||
self,
|
|
||||||
image_size,
|
|
||||||
in_channels,
|
|
||||||
model_channels,
|
|
||||||
out_channels,
|
|
||||||
num_res_blocks,
|
|
||||||
attention_resolutions,
|
|
||||||
dropout=0,
|
|
||||||
channel_mult=(1, 2, 4, 8),
|
|
||||||
conv_resample=True,
|
|
||||||
dims=2,
|
|
||||||
use_checkpoint=False,
|
|
||||||
use_fp16=False,
|
|
||||||
num_heads=1,
|
|
||||||
num_head_channels=-1,
|
|
||||||
num_heads_upsample=-1,
|
|
||||||
use_scale_shift_norm=False,
|
|
||||||
resblock_updown=False,
|
|
||||||
use_new_attention_order=False,
|
|
||||||
pool="adaptive",
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if num_heads_upsample == -1:
|
|
||||||
num_heads_upsample = num_heads
|
|
||||||
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.model_channels = model_channels
|
inner_dim = n_heads * d_head
|
||||||
self.out_channels = out_channels
|
self.norm = Normalize(in_channels)
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.attention_resolutions = attention_resolutions
|
|
||||||
self.dropout = dropout
|
|
||||||
self.channel_mult = channel_mult
|
|
||||||
self.conv_resample = conv_resample
|
|
||||||
self.use_checkpoint = use_checkpoint
|
|
||||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.num_head_channels = num_head_channels
|
|
||||||
self.num_heads_upsample = num_heads_upsample
|
|
||||||
|
|
||||||
time_embed_dim = model_channels * 4
|
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||||
self.time_embed = nn.Sequential(
|
|
||||||
linear(model_channels, time_embed_dim),
|
|
||||||
nn.SiLU(),
|
|
||||||
linear(time_embed_dim, time_embed_dim),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.input_blocks = nn.ModuleList(
|
self.transformer_blocks = nn.ModuleList(
|
||||||
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
|
[
|
||||||
)
|
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||||
self._feature_size = model_channels
|
for d in range(depth)
|
||||||
input_block_chans = [model_channels]
|
|
||||||
ch = model_channels
|
|
||||||
ds = 1
|
|
||||||
for level, mult in enumerate(channel_mult):
|
|
||||||
for _ in range(num_res_blocks):
|
|
||||||
layers = [
|
|
||||||
ResnetBlock(
|
|
||||||
in_channels=ch,
|
|
||||||
out_channels=model_channels * mult,
|
|
||||||
dropout=dropout,
|
|
||||||
temb_channels=time_embed_dim,
|
|
||||||
eps=1e-5,
|
|
||||||
non_linearity="silu",
|
|
||||||
overwrite_for_ldm=True,
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
ch = mult * model_channels
|
|
||||||
if ds in attention_resolutions:
|
|
||||||
layers.append(
|
|
||||||
AttentionBlock(
|
|
||||||
ch,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_head_channels=num_head_channels,
|
|
||||||
use_new_attention_order=use_new_attention_order,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
|
||||||
self._feature_size += ch
|
|
||||||
input_block_chans.append(ch)
|
|
||||||
if level != len(channel_mult) - 1:
|
|
||||||
out_ch = ch
|
|
||||||
self.input_blocks.append(
|
|
||||||
TimestepEmbedSequential(
|
|
||||||
# ResBlock(
|
|
||||||
# ch,
|
|
||||||
# time_embed_dim,
|
|
||||||
# dropout,
|
|
||||||
# out_channels=out_ch,
|
|
||||||
# dims=dims,
|
|
||||||
# use_checkpoint=use_checkpoint,
|
|
||||||
# use_scale_shift_norm=use_scale_shift_norm,
|
|
||||||
# down=True,
|
|
||||||
# )
|
|
||||||
None
|
|
||||||
if resblock_updown
|
|
||||||
else Downsample(
|
|
||||||
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
ch = out_ch
|
|
||||||
input_block_chans.append(ch)
|
|
||||||
ds *= 2
|
|
||||||
self._feature_size += ch
|
|
||||||
|
|
||||||
self.middle_block = TimestepEmbedSequential(
|
|
||||||
ResnetBlock(
|
|
||||||
in_channels=ch,
|
|
||||||
out_channels=None,
|
|
||||||
dropout=dropout,
|
|
||||||
temb_channels=time_embed_dim,
|
|
||||||
eps=1e-5,
|
|
||||||
non_linearity="silu",
|
|
||||||
overwrite_for_ldm=True,
|
|
||||||
),
|
|
||||||
AttentionBlock(
|
|
||||||
ch,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_head_channels=num_head_channels,
|
|
||||||
use_new_attention_order=use_new_attention_order,
|
|
||||||
),
|
|
||||||
ResnetBlock(
|
|
||||||
in_channels=ch,
|
|
||||||
out_channels=None,
|
|
||||||
dropout=dropout,
|
|
||||||
temb_channels=time_embed_dim,
|
|
||||||
eps=1e-5,
|
|
||||||
non_linearity="silu",
|
|
||||||
overwrite_for_ldm=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self._feature_size += ch
|
|
||||||
self.pool = pool
|
|
||||||
if pool == "adaptive":
|
|
||||||
self.out = nn.Sequential(
|
|
||||||
normalization(ch),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.AdaptiveAvgPool2d((1, 1)),
|
|
||||||
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
|
||||||
nn.Flatten(),
|
|
||||||
)
|
|
||||||
elif pool == "attention":
|
|
||||||
assert num_head_channels != -1
|
|
||||||
self.out = nn.Sequential(
|
|
||||||
normalization(ch),
|
|
||||||
nn.SiLU(),
|
|
||||||
AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
|
|
||||||
)
|
|
||||||
elif pool == "spatial":
|
|
||||||
self.out = nn.Sequential(
|
|
||||||
nn.Linear(self._feature_size, 2048),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(2048, self.out_channels),
|
|
||||||
)
|
|
||||||
elif pool == "spatial_v2":
|
|
||||||
self.out = nn.Sequential(
|
|
||||||
nn.Linear(self._feature_size, 2048),
|
|
||||||
normalization(2048),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(2048, self.out_channels),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unexpected {pool} pooling")
|
|
||||||
|
|
||||||
def convert_to_fp16(self):
|
|
||||||
"""
|
|
||||||
Convert the torso of the model to float16.
|
|
||||||
"""
|
|
||||||
self.input_blocks.apply(convert_module_to_f16)
|
|
||||||
self.middle_block.apply(convert_module_to_f16)
|
|
||||||
|
|
||||||
def convert_to_fp32(self):
|
|
||||||
"""
|
|
||||||
Convert the torso of the model to float32.
|
|
||||||
"""
|
|
||||||
self.input_blocks.apply(convert_module_to_f32)
|
|
||||||
self.middle_block.apply(convert_module_to_f32)
|
|
||||||
|
|
||||||
def forward(self, x, timesteps):
|
|
||||||
"""
|
|
||||||
Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch
|
|
||||||
of timesteps. :return: an [N x K] Tensor of outputs.
|
|
||||||
"""
|
|
||||||
emb = self.time_embed(
|
|
||||||
get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = []
|
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
||||||
h = x.type(self.dtype)
|
|
||||||
for module in self.input_blocks:
|
def forward(self, x, context=None):
|
||||||
h = module(h, emb)
|
# note: if no context is given, cross-attention defaults to self-attention
|
||||||
if self.pool.startswith("spatial"):
|
b, c, h, w = x.shape
|
||||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
x_in = x
|
||||||
h = self.middle_block(h, emb)
|
x = self.norm(x)
|
||||||
if self.pool.startswith("spatial"):
|
x = self.proj_in(x)
|
||||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
||||||
h = torch.cat(results, axis=-1)
|
for block in self.transformer_blocks:
|
||||||
return self.out(h)
|
x = block(x, context=context)
|
||||||
else:
|
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
||||||
h = h.type(x.dtype)
|
x = self.proj_out(x)
|
||||||
return self.out(h)
|
return x + x_in
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock(nn.Module):
|
||||||
|
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
|
||||||
|
super().__init__()
|
||||||
|
self.attn1 = CrossAttention(
|
||||||
|
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||||
|
) # is a self-attention
|
||||||
|
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||||
|
self.attn2 = CrossAttention(
|
||||||
|
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||||
|
) # is self-attn if context is none
|
||||||
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
self.norm3 = nn.LayerNorm(dim)
|
||||||
|
self.checkpoint = checkpoint
|
||||||
|
|
||||||
|
def forward(self, x, context=None):
|
||||||
|
x = self.attn1(self.norm1(x)) + x
|
||||||
|
x = self.attn2(self.norm2(x), context=context) + x
|
||||||
|
x = self.ff(self.norm3(x)) + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttention(nn.Module):
|
||||||
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
context_dim = default(context_dim, query_dim)
|
||||||
|
|
||||||
|
self.scale = dim_head**-0.5
|
||||||
|
self.heads = heads
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||||
|
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||||
|
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||||
|
|
||||||
|
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||||
|
|
||||||
|
def reshape_heads_to_batch_dim(self, tensor):
|
||||||
|
batch_size, seq_len, dim = tensor.shape
|
||||||
|
head_size = self.heads
|
||||||
|
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
||||||
|
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
def reshape_batch_dim_to_heads(self, tensor):
|
||||||
|
batch_size, seq_len, dim = tensor.shape
|
||||||
|
head_size = self.heads
|
||||||
|
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||||
|
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
def forward(self, x, context=None, mask=None):
|
||||||
|
batch_size, sequence_length, dim = x.shape
|
||||||
|
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
|
||||||
|
q = self.reshape_heads_to_batch_dim(q)
|
||||||
|
k = self.reshape_heads_to_batch_dim(k)
|
||||||
|
v = self.reshape_heads_to_batch_dim(v)
|
||||||
|
|
||||||
|
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
mask = mask.reshape(batch_size, -1)
|
||||||
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
mask = mask[:, None, :].repeat(h, 1, 1)
|
||||||
|
sim.masked_fill_(~mask, max_neg_value)
|
||||||
|
|
||||||
|
# attention, what we cannot get enough of
|
||||||
|
attn = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
||||||
|
out = self.reshape_batch_dim_to_heads(out)
|
||||||
|
return self.to_out(out)
|
||||||
|
|
Loading…
Reference in New Issue