parent
4e2674934f
commit
94566e6dd8
|
@ -1,6 +1,8 @@
|
|||
import math
|
||||
from inspect import isfunction
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
|
@ -43,18 +45,16 @@ class AttentionBlock(nn.Module):
|
|||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
num_head_channels=None,
|
||||
num_groups=32,
|
||||
use_checkpoint=False,
|
||||
encoder_channels=None,
|
||||
use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete?
|
||||
overwrite_qkv=False,
|
||||
overwrite_linear=False,
|
||||
rescale_output_factor=1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels == -1:
|
||||
if num_head_channels is None:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
|
@ -62,7 +62,6 @@ class AttentionBlock(nn.Module):
|
|||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
|
||||
self.qkv = nn.Conv1d(channels, channels * 3, 1)
|
||||
self.n_heads = self.num_heads
|
||||
|
@ -160,115 +159,135 @@ class AttentionBlock(nn.Module):
|
|||
return result
|
||||
|
||||
|
||||
# unet_score_estimation.py
|
||||
# class AttnBlockpp(nn.Module):
|
||||
# """Channel-wise self-attention block. Modified from DDPM."""
|
||||
#
|
||||
# def __init__(
|
||||
# self,
|
||||
# channels,
|
||||
# skip_rescale=False,
|
||||
# init_scale=0.0,
|
||||
# num_heads=1,
|
||||
# num_head_channels=-1,
|
||||
# use_checkpoint=False,
|
||||
# encoder_channels=None,
|
||||
# use_new_attention_order=False, # TODO(Patrick) -> is never used, maybe delete?
|
||||
# overwrite_qkv=False,
|
||||
# overwrite_from_grad_tts=False,
|
||||
# ):
|
||||
# super().__init__()
|
||||
# num_groups = min(channels // 4, 32)
|
||||
# self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
|
||||
# self.NIN_0 = NIN(channels, channels)
|
||||
# self.NIN_1 = NIN(channels, channels)
|
||||
# self.NIN_2 = NIN(channels, channels)
|
||||
# self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
|
||||
# self.skip_rescale = skip_rescale
|
||||
#
|
||||
# self.channels = channels
|
||||
# if num_head_channels == -1:
|
||||
# self.num_heads = num_heads
|
||||
# else:
|
||||
# assert (
|
||||
# channels % num_head_channels == 0
|
||||
# ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
# self.num_heads = channels // num_head_channels
|
||||
#
|
||||
# self.use_checkpoint = use_checkpoint
|
||||
# self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
|
||||
# self.qkv = nn.Conv1d(channels, channels * 3, 1)
|
||||
# self.n_heads = self.num_heads
|
||||
#
|
||||
# self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
|
||||
#
|
||||
# self.is_weight_set = False
|
||||
#
|
||||
# def set_weights(self):
|
||||
# self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
|
||||
# self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
|
||||
#
|
||||
# self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
|
||||
# self.proj_out.bias.data = self.NIN_3.b.data
|
||||
#
|
||||
# self.norm.weight.data = self.GroupNorm_0.weight.data
|
||||
# self.norm.bias.data = self.GroupNorm_0.bias.data
|
||||
#
|
||||
# def forward(self, x):
|
||||
# if not self.is_weight_set:
|
||||
# self.set_weights()
|
||||
# self.is_weight_set = True
|
||||
#
|
||||
# B, C, H, W = x.shape
|
||||
# h = self.GroupNorm_0(x)
|
||||
# q = self.NIN_0(h)
|
||||
# k = self.NIN_1(h)
|
||||
# v = self.NIN_2(h)
|
||||
#
|
||||
# w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
|
||||
# w = torch.reshape(w, (B, H, W, H * W))
|
||||
# w = F.softmax(w, dim=-1)
|
||||
# w = torch.reshape(w, (B, H, W, H, W))
|
||||
# h = torch.einsum("bhwij,bcij->bchw", w, v)
|
||||
# h = self.NIN_3(h)
|
||||
#
|
||||
# if not self.skip_rescale:
|
||||
# result = x + h
|
||||
# else:
|
||||
# result = (x + h) / np.sqrt(2.0)
|
||||
#
|
||||
# result = self.forward_2(x)
|
||||
#
|
||||
# return result
|
||||
#
|
||||
# def forward_2(self, x, encoder_out=None):
|
||||
# b, c, *spatial = x.shape
|
||||
# hid_states = self.norm(x).view(b, c, -1)
|
||||
#
|
||||
# qkv = self.qkv(hid_states)
|
||||
# bs, width, length = qkv.shape
|
||||
# assert width % (3 * self.n_heads) == 0
|
||||
# ch = width // (3 * self.n_heads)
|
||||
# q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
#
|
||||
# if encoder_out is not None:
|
||||
# encoder_kv = self.encoder_kv(encoder_out)
|
||||
# assert encoder_kv.shape[1] == self.n_heads * ch * 2
|
||||
# ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
|
||||
# k = torch.cat([ek, k], dim=-1)
|
||||
# v = torch.cat([ev, v], dim=-1)
|
||||
#
|
||||
# scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
# weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
||||
# weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
#
|
||||
# a = torch.einsum("bts,bcs->bct", weight, v)
|
||||
# h = a.reshape(bs, -1, length)
|
||||
#
|
||||
# h = self.proj_out(h)
|
||||
# h = h.reshape(b, c, *spatial)
|
||||
#
|
||||
# return (x + h) / np.sqrt(2.0)
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
|
||||
standard transformer action. Finally, reshape to image
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = self.proj_in(x)
|
||||
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
|
||||
super().__init__()
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
return tensor
|
||||
|
||||
def reshape_batch_dim_to_heads(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
batch_size, sequence_length, dim = x.shape
|
||||
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q = self.reshape_heads_to_batch_dim(q)
|
||||
k = self.reshape_heads_to_batch_dim(k)
|
||||
v = self.reshape_heads_to_batch_dim(v)
|
||||
|
||||
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = mask.reshape(batch_size, -1)
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = mask[:, None, :].repeat(h, 1, 1)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = torch.einsum("b i j, b j d -> b i d", attn, v)
|
||||
out = self.reshape_batch_dim_to_heads(out)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
# TODO(Patrick) - this can and should be removed
|
||||
|
@ -287,3 +306,24 @@ class NIN(nn.Module):
|
|||
super().__init__()
|
||||
self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
|
||||
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
|
|
@ -23,6 +23,7 @@ from ..modeling_utils import ModelMixin
|
|||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from .unet_new import UNetMidBlock2D
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
|
@ -105,13 +106,8 @@ class UNetModel(ModelMixin, ConfigMixin):
|
|||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock2D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
|
||||
self.mid.block_2 = ResnetBlock2D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
self.mid = UNetMidBlock2D(
|
||||
in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, overwrite_qkv=True, overwrite_unet=True
|
||||
)
|
||||
|
||||
# upsampling
|
||||
|
@ -171,10 +167,10 @@ class UNetModel(ModelMixin, ConfigMixin):
|
|||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
h = self.mid(hs[-1], temb)
|
||||
# h = self.mid.block_1(h, temb)
|
||||
# h = self.mid.attn_1(h)
|
||||
# h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
|
|
|
@ -7,6 +7,7 @@ from ..modeling_utils import ModelMixin
|
|||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from .unet_new import UNetMidBlock2D
|
||||
|
||||
|
||||
def convert_module_to_f16(l):
|
||||
|
@ -193,7 +194,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
|||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=transformer_dim,
|
||||
|
@ -226,6 +226,20 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
|||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
self.mid = UNetMidBlock2D(
|
||||
in_channels=ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=1e-5,
|
||||
resnet_act_fn="silu",
|
||||
resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
|
||||
attn_num_heads=num_heads,
|
||||
attn_num_head_channels=num_head_channels,
|
||||
attn_encoder_channels=transformer_dim,
|
||||
)
|
||||
|
||||
# TODO(Patrick) - delete after weight conversion
|
||||
# init to be able to overwrite `self.mid`
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
|
@ -238,7 +252,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
|||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=transformer_dim,
|
||||
|
@ -253,6 +266,10 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
|||
overwrite_for_glide=True,
|
||||
),
|
||||
)
|
||||
self.mid.resnet_1 = self.middle_block[0]
|
||||
self.mid.attn = self.middle_block[1]
|
||||
self.mid.resnet_2 = self.middle_block[2]
|
||||
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
|
@ -276,7 +293,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
|||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=num_head_channels,
|
||||
encoder_channels=transformer_dim,
|
||||
|
@ -343,7 +359,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
|
|||
for module in self.input_blocks:
|
||||
h = module(h, emb)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb)
|
||||
h = self.mid(h, emb)
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb)
|
||||
|
@ -438,7 +454,7 @@ class GlideTextToImageUNetModel(GlideUNetModel):
|
|||
for module in self.input_blocks:
|
||||
h = module(h, emb, transformer_out)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, transformer_out)
|
||||
h = self.mid(h, emb, transformer_out)
|
||||
for module in self.output_blocks:
|
||||
other = hs.pop()
|
||||
h = torch.cat([h, other], dim=1)
|
||||
|
|
|
@ -133,6 +133,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
|||
overwrite_for_grad_tts=True,
|
||||
)
|
||||
|
||||
# self.mid = UNetMidBlock2D
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
self.ups.append(
|
||||
torch.nn.ModuleList(
|
||||
|
|
|
@ -11,6 +11,7 @@ from ..modeling_utils import ModelMixin
|
|||
from .attention import AttentionBlock
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from .unet_new import UNetMidBlock2D
|
||||
|
||||
|
||||
# from .resnet import ResBlock
|
||||
|
@ -239,14 +240,12 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
conv_resample=conv_resample,
|
||||
dims=dims,
|
||||
num_classes=num_classes,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_fp16=use_fp16,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
num_heads_upsample=num_heads_upsample,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
resblock_updown=resblock_updown,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
use_spatial_transformer=use_spatial_transformer,
|
||||
transformer_depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
|
@ -283,7 +282,6 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype_ = torch.float16 if use_fp16 else torch.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
|
@ -333,10 +331,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer(
|
||||
|
@ -366,6 +362,25 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
|
||||
if dim_head < 0:
|
||||
dim_head = None
|
||||
self.mid = UNetMidBlock2D(
|
||||
in_channels=ch,
|
||||
dropout=dropout,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=1e-5,
|
||||
resnet_act_fn="silu",
|
||||
resnet_time_scale_shift="scale_shift" if use_scale_shift_norm else "default",
|
||||
attention_layer_type="self" if not use_spatial_transformer else "spatial",
|
||||
attn_num_heads=num_heads,
|
||||
attn_num_head_channels=dim_head,
|
||||
attn_depth=transformer_depth,
|
||||
attn_encoder_channels=context_dim,
|
||||
)
|
||||
|
||||
# TODO(Patrick) - delete after weight conversion
|
||||
# init to be able to overwrite `self.mid`
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResnetBlock2D(
|
||||
in_channels=ch,
|
||||
|
@ -378,10 +393,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
|
||||
|
@ -395,6 +408,10 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
overwrite_for_ldm=True,
|
||||
),
|
||||
)
|
||||
self.mid.resnet_1 = self.middle_block[0]
|
||||
self.mid.attn = self.middle_block[1]
|
||||
self.mid.resnet_2 = self.middle_block[2]
|
||||
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
|
@ -425,10 +442,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
if not use_spatial_transformer
|
||||
else SpatialTransformer(
|
||||
|
@ -493,7 +508,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
for module in self.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
h = self.mid(h, emb, context)
|
||||
for module in self.output_blocks:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
from torch import nn
|
||||
|
||||
from .attention import AttentionBlock, SpatialTransformer
|
||||
from .resnet import ResnetBlock2D
|
||||
|
||||
|
||||
class UNetMidBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
attention_layer_type: str = "self",
|
||||
attn_num_heads=1,
|
||||
attn_num_head_channels=None,
|
||||
attn_encoder_channels=None,
|
||||
attn_dim_head=None,
|
||||
attn_depth=None,
|
||||
output_scale_factor=1.0,
|
||||
overwrite_qkv=False,
|
||||
overwrite_unet=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.resnet_1 = ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
)
|
||||
|
||||
if attention_layer_type == "self":
|
||||
self.attn = AttentionBlock(
|
||||
in_channels,
|
||||
num_heads=attn_num_heads,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
encoder_channels=attn_encoder_channels,
|
||||
overwrite_qkv=overwrite_qkv,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
)
|
||||
elif attention_layer_type == "spatial":
|
||||
self.attn = (
|
||||
SpatialTransformer(
|
||||
in_channels,
|
||||
attn_num_heads,
|
||||
attn_num_head_channels,
|
||||
depth=attn_depth,
|
||||
context_dim=attn_encoder_channels,
|
||||
),
|
||||
)
|
||||
|
||||
self.resnet_2 = ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
)
|
||||
|
||||
# TODO(Patrick) - delete all of the following code
|
||||
self.is_overwritten = False
|
||||
self.overwrite_unet = overwrite_unet
|
||||
if self.overwrite_unet:
|
||||
block_in = in_channels
|
||||
self.temb_ch = temb_channels
|
||||
self.block_1 = ResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
)
|
||||
self.attn_1 = AttentionBlock(
|
||||
block_in,
|
||||
num_heads=attn_num_heads,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
encoder_channels=attn_encoder_channels,
|
||||
overwrite_qkv=True,
|
||||
)
|
||||
self.block_2 = ResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
eps=resnet_eps,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_states=None):
|
||||
if not self.is_overwritten and self.overwrite_unet:
|
||||
self.resnet_1 = self.block_1
|
||||
self.attn = self.attn_1
|
||||
self.resnet_2 = self.block_2
|
||||
self.is_overwritten = True
|
||||
|
||||
hidden_states = self.resnet_1(hidden_states, temb)
|
||||
|
||||
if encoder_states is None:
|
||||
hidden_states = self.attn(hidden_states)
|
||||
else:
|
||||
hidden_states = self.attn(hidden_states, encoder_states)
|
||||
|
||||
hidden_states = self.resnet_2(hidden_states, temb)
|
||||
return hidden_states
|
|
@ -27,6 +27,7 @@ from ..modeling_utils import ModelMixin
|
|||
from .attention import AttentionBlock
|
||||
from .embeddings import GaussianFourierProjection, get_timestep_embedding
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
from .unet_new import UNetMidBlock2D
|
||||
|
||||
|
||||
class Combine(nn.Module):
|
||||
|
@ -214,6 +215,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
|
||||
hs_c.append(in_ch)
|
||||
|
||||
# mid
|
||||
self.mid = UNetMidBlock2D(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=math.sqrt(2.0),
|
||||
resnet_act_fn="silu",
|
||||
resnet_groups=min(in_ch // 4, 32),
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
in_ch = hs_c[-1]
|
||||
modules.append(
|
||||
ResnetBlock2D(
|
||||
|
@ -238,6 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
self.mid.resnet_1 = modules[len(modules) - 3]
|
||||
self.mid.attn = modules[len(modules) - 2]
|
||||
self.mid.resnet_2 = modules[len(modules) - 1]
|
||||
|
||||
pyramid_ch = 0
|
||||
# Upsampling block
|
||||
|
@ -378,13 +392,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
|
||||
hs.append(h)
|
||||
|
||||
h = hs[-1]
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h)
|
||||
m_idx += 1
|
||||
h = modules[m_idx](h, temb)
|
||||
m_idx += 1
|
||||
# h = hs[-1]
|
||||
# h = modules[m_idx](h, temb)
|
||||
# m_idx += 1
|
||||
# h = modules[m_idx](h)
|
||||
# m_idx += 1
|
||||
# h = modules[m_idx](h, temb)
|
||||
# m_idx += 1
|
||||
|
||||
h = self.mid(h, temb)
|
||||
m_idx += 3
|
||||
|
||||
pyramid = None
|
||||
|
||||
|
|
Loading…
Reference in New Issue