Merge pull request #46 from huggingface/merge_ldm_resnet
[ResNet Refactor] Merge ldm into resnet
This commit is contained in:
commit b65eb377dd
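For orientation, the change applied throughout `unet_ldm.py` in this diff is the substitution of the glide/ldm-style `ResBlock` call by the merged `ResnetBlock` call. The block below is only a side-by-side summary of the calls as they appear in this diff, not new API surface:

```python
# Before (glide/ldm-style block, positional time-embedding and dropout args):
#
#     ResBlock(
#         ch,
#         time_embed_dim,
#         dropout,
#         out_channels=mult * model_channels,
#         dims=dims,
#         use_checkpoint=use_checkpoint,
#         use_scale_shift_norm=use_scale_shift_norm,
#     )
#
# After (merged block, explicit keywords; LDM behaviour is enabled through the
# temporary overwrite flag while the weights are ported at runtime):
#
#     ResnetBlock(
#         in_channels=ch,
#         out_channels=mult * model_channels,
#         dropout=dropout,
#         temb_channels=time_embed_dim,
#         eps=1e-5,
#         non_linearity="silu",
#         overwrite_for_ldm=True,
#     )
```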
@@ -162,7 +162,7 @@ class Downsample(nn.Module):

# RESNETS

# unet_glide.py & unet_ldm.py
# unet_glide.py
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
@@ -188,6 +188,7 @@ class ResBlock(TimestepBlock):
        use_checkpoint=False,
        up=False,
        down=False,
        overwrite=False,  # TODO(Patrick) - use for glide at later stage
    ):
        super().__init__()
        self.channels = channels
@@ -236,6 +237,65 @@ class ResBlock(TimestepBlock):
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

        self.overwrite = overwrite
        self.is_overwritten = False
        if self.overwrite:
            in_channels = channels
            out_channels = self.out_channels
            conv_shortcut = False
            dropout = 0.0
            temb_channels = emb_channels
            groups = 32
            pre_norm = True
            eps = 1e-5
            non_linearity = "silu"
            self.pre_norm = pre_norm
            self.in_channels = in_channels
            out_channels = in_channels if out_channels is None else out_channels
            self.out_channels = out_channels
            self.use_conv_shortcut = conv_shortcut

            if self.pre_norm:
                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
            else:
                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)

            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
            self.dropout = torch.nn.Dropout(dropout)
            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
            if non_linearity == "swish":
                self.nonlinearity = nonlinearity
            elif non_linearity == "mish":
                self.nonlinearity = Mish()
            elif non_linearity == "silu":
                self.nonlinearity = nn.SiLU()

            if self.in_channels != self.out_channels:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def set_weights(self):
        # TODO(Patrick): use for glide at later stage
        self.norm1.weight.data = self.in_layers[0].weight.data
        self.norm1.bias.data = self.in_layers[0].bias.data

        self.conv1.weight.data = self.in_layers[-1].weight.data
        self.conv1.bias.data = self.in_layers[-1].bias.data

        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
        self.temb_proj.bias.data = self.emb_layers[-1].bias.data

        self.norm2.weight.data = self.out_layers[0].weight.data
        self.norm2.bias.data = self.out_layers[0].bias.data

        self.conv2.weight.data = self.out_layers[-1].weight.data
        self.conv2.bias.data = self.out_layers[-1].bias.data

        if self.in_channels != self.out_channels:
            self.nin_shortcut.weight.data = self.skip_connection.weight.data
            self.nin_shortcut.bias.data = self.skip_connection.bias.data

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
@@ -243,6 +303,10 @@ class ResBlock(TimestepBlock):
        :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.overwrite:
            # TODO(Patrick): use for glide at later stage
            self.set_weights()

        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
@@ -251,6 +315,7 @@ class ResBlock(TimestepBlock):
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
@@ -262,7 +327,50 @@ class ResBlock(TimestepBlock):
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

        result = self.skip_connection(x) + h

        # TODO(Patrick) Use for glide at later stage
        # result = self.forward_2(x, emb)

        return result

    def forward_2(self, x, temb, mask=1.0):
        if self.overwrite and not self.is_overwritten:
            self.set_weights()
            self.is_overwritten = True

        h = x
        if self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)

        h = self.conv1(h)

        if not self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)

        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]

        if self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)

        h = self.dropout(h)
        h = self.conv2(h)

        if not self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


# unet.py and unet_grad_tts.py
@@ -280,6 +388,7 @@ class ResnetBlock(nn.Module):
        eps=1e-6,
        non_linearity="swish",
        overwrite_for_grad_tts=False,
        overwrite_for_ldm=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
@@ -302,15 +411,19 @@ class ResnetBlock(nn.Module):
            self.nonlinearity = nonlinearity
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                # TODO(Patrick) - this branch is never used I think => can be deleted!
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

        self.is_overwritten = False
        self.overwrite_for_grad_tts = overwrite_for_grad_tts
        self.overwrite_for_ldm = overwrite_for_ldm
        if self.overwrite_for_grad_tts:
            dim = in_channels
            dim_out = out_channels
@@ -324,6 +437,39 @@ class ResnetBlock(nn.Module):
                self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
            else:
                self.res_conv = torch.nn.Identity()
        elif self.overwrite_for_ldm:
            dims = 2
            # eps = 1e-5
            # non_linearity = "silu"
            # overwrite_for_ldm
            channels = in_channels
            emb_channels = temb_channels
            use_scale_shift_norm = False

            self.in_layers = nn.Sequential(
                normalization(channels, swish=1.0),
                nn.Identity(),
                conv_nd(dims, channels, self.out_channels, 3, padding=1),
            )
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                linear(
                    emb_channels,
                    2 * self.out_channels if use_scale_shift_norm else self.out_channels,
                ),
            )
            self.out_layers = nn.Sequential(
                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
                nn.Dropout(p=dropout),
                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
            )
            if self.out_channels == in_channels:
                self.skip_connection = nn.Identity()
            # elif use_conv:
            # self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
            else:
                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def set_weights_grad_tts(self):
        self.conv1.weight.data = self.block1.block[0].weight.data
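For readers unfamiliar with the ADM-style helpers used in the `overwrite_for_ldm` branch above: `zero_module` zero-initialises the wrapped layer so the residual branch contributes nothing at initialisation, and `normalization` is a GroupNorm with an optionally fused swish. A hedged re-implementation sketch of the `zero_module` idea (not the repo's exact code):

```python
import torch
from torch import nn


def zero_module(module: nn.Module) -> nn.Module:
    """Zero all parameters so the wrapped layer initially outputs zeros."""
    for p in module.parameters():
        p.detach().zero_()
    return module


conv = zero_module(nn.Conv2d(64, 64, 3, padding=1))
x = torch.randn(1, 64, 8, 8)
assert torch.all(conv(x) == 0)  # residual branch starts as a no-op
```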
@@ -343,13 +489,36 @@ class ResnetBlock(nn.Module):
            self.nin_shortcut.weight.data = self.res_conv.weight.data
            self.nin_shortcut.bias.data = self.res_conv.bias.data

    def forward(self, x, temb, mask=None):
    def set_weights_ldm(self):
        self.norm1.weight.data = self.in_layers[0].weight.data
        self.norm1.bias.data = self.in_layers[0].bias.data

        self.conv1.weight.data = self.in_layers[-1].weight.data
        self.conv1.bias.data = self.in_layers[-1].bias.data

        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
        self.temb_proj.bias.data = self.emb_layers[-1].bias.data

        self.norm2.weight.data = self.out_layers[0].weight.data
        self.norm2.bias.data = self.out_layers[0].bias.data

        self.conv2.weight.data = self.out_layers[-1].weight.data
        self.conv2.bias.data = self.out_layers[-1].bias.data

        if self.in_channels != self.out_channels:
            self.nin_shortcut.weight.data = self.skip_connection.weight.data
            self.nin_shortcut.bias.data = self.skip_connection.bias.data

    def forward(self, x, temb, mask=1.0):
        if self.overwrite_for_grad_tts and not self.is_overwritten:
            self.set_weights_grad_tts()
            self.is_overwritten = True
        elif self.overwrite_for_ldm and not self.is_overwritten:
            self.set_weights_ldm()
            self.is_overwritten = True

        h = x
        h = h * mask if mask is not None else h
        h = h * mask
        if self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)
@@ -359,11 +528,11 @@ class ResnetBlock(nn.Module):
        if not self.pre_norm:
            h = self.norm1(h)
            h = self.nonlinearity(h)
        h = h * mask if mask is not None else h
        h = h * mask

        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]

        h = h * mask if mask is not None else h
        h = h * mask
        if self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)
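The mask handling in this hunk and the next changes from a `None` check to multiplying by a default scalar of `1.0`. A tiny sketch of why that is behaviour-preserving for callers that never pass a mask (the tensor shape is made up for illustration):

```python
import torch

h = torch.randn(2, 64, 32, 32)  # any float activation tensor
mask = 1.0                      # the new default when no mask is supplied

# Multiplying by the scalar default is an identity, so the LDM/glide paths,
# which pass no mask, are unaffected; grad-tts can still pass a real
# broadcastable mask tensor instead.
assert torch.equal(h * mask, h)
```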
@@ -374,9 +543,9 @@ class ResnetBlock(nn.Module):
        if not self.pre_norm:
            h = self.norm2(h)
            h = self.nonlinearity(h)
        h = h * mask if mask is not None else h
        h = h * mask

        x = x * mask if mask is not None else x
        x = x * mask
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
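The core of the merge is the `set_weights` / `set_weights_ldm` pattern above: the LDM/glide parameters packed into the `in_layers` / `emb_layers` / `out_layers` Sequentials are copied into the unpacked `norm1` / `conv1` / `temb_proj` / `norm2` / `conv2` modules. A minimal, self-contained sketch of that idea; the sizes and the plain `nn.GroupNorm` / `nn.SiLU` stand-ins are assumptions for illustration (the repo's `normalization` helper fuses the swish into the GroupNorm):

```python
import torch
from torch import nn

channels, emb_channels = 64, 256  # made-up sizes for the sketch

# LDM/glide-style packing: parameters live inside nn.Sequential containers.
in_layers = nn.Sequential(
    nn.GroupNorm(32, channels),
    nn.SiLU(),
    nn.Conv2d(channels, channels, 3, padding=1),
)
emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(emb_channels, channels))

# Unpacked layout used by the merged ResnetBlock.
norm1 = nn.GroupNorm(32, channels)
conv1 = nn.Conv2d(channels, channels, 3, padding=1)
temb_proj = nn.Linear(emb_channels, channels)

# Same copy pattern as set_weights / set_weights_ldm in the diff
# (index the first / last module of each Sequential).
norm1.load_state_dict(in_layers[0].state_dict())
conv1.load_state_dict(in_layers[-1].state_dict())
temb_proj.load_state_dict(emb_layers[-1].state_dict())

# After porting, the packed and unpacked paths compute the same activations.
x = torch.randn(1, channels, 16, 16)
assert torch.allclose(in_layers(x), conv1(nn.functional.silu(norm1(x))))
```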
@@ -10,7 +10,10 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample


# from .resnet import ResBlock


def exists(val):
@@ -364,7 +367,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
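Context for the one-line change above: the merged `ResnetBlock` is a plain `nn.Module` rather than a `TimestepBlock` subclass, so `TimestepEmbedSequential` has to match it explicitly to keep handing it the timestep embedding. A stripped-down, runnable sketch of that dispatch; the `Fake*` classes are stand-ins, not the repo's real modules:

```python
import torch
from torch import nn


class TimestepBlock(nn.Module):
    """Marker base class: forward takes (x, emb)."""


class FakeResnetBlock(nn.Module):          # stand-in for the merged ResnetBlock
    def forward(self, x, emb):
        return x + emb.mean()


class FakeSpatialTransformer(nn.Module):   # stand-in for SpatialTransformer
    def forward(self, x, context):
        return x if context is None else x + context.mean()


def forward_sequential(layers, x, emb, context=None):
    for layer in layers:
        if isinstance(layer, (TimestepBlock, FakeResnetBlock)):
            x = layer(x, emb)        # timestep-conditioned modules
        elif isinstance(layer, FakeSpatialTransformer):
            x = layer(x, context)    # cross-attention modules
        else:
            x = layer(x)             # everything else (attention, up/downsample, ...)
    return x


out = forward_sequential(
    [FakeResnetBlock(), nn.Identity(), FakeSpatialTransformer()],
    torch.randn(1, 4, 8, 8),
    torch.randn(1, 16),
)
print(out.shape)  # torch.Size([1, 4, 8, 8])
```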
@@ -559,14 +562,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                    ResnetBlock(
                        in_channels=ch,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dropout=dropout,
                        temb_channels=time_embed_dim,
                        eps=1e-5,
                        non_linearity="silu",
                        overwrite_for_ldm=True,
                    )
                ]
                ch = mult * model_channels
@@ -599,16 +602,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        # ResBlock(
                        # ch,
                        # time_embed_dim,
                        # dropout,
                        # out_channels=out_ch,
                        # dims=dims,
                        # use_checkpoint=use_checkpoint,
                        # use_scale_shift_norm=use_scale_shift_norm,
                        # down=True,
                        # )
                        None
                        if resblock_updown
                        else Downsample(
                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
@@ -629,13 +633,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
        # num_heads = 1
        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ResnetBlock(
                in_channels=ch,
                out_channels=None,
                dropout=dropout,
                temb_channels=time_embed_dim,
                eps=1e-5,
                non_linearity="silu",
                overwrite_for_ldm=True,
            ),
            AttentionBlock(
                ch,
@@ -646,13 +651,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
            )
            if not use_spatial_transformer
            else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ResnetBlock(
                in_channels=ch,
                out_channels=None,
                dropout=dropout,
                temb_channels=time_embed_dim,
                eps=1e-5,
                non_linearity="silu",
                overwrite_for_ldm=True,
            ),
        )
        self._feature_size += ch
@@ -662,15 +668,15 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                    ResnetBlock(
                        in_channels=ch + ich,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                        dropout=dropout,
                        temb_channels=time_embed_dim,
                        eps=1e-5,
                        non_linearity="silu",
                        overwrite_for_ldm=True,
                    ),
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
@@ -698,16 +704,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        # ResBlock(
                        # ch,
                        # time_embed_dim,
                        # dropout,
                        # out_channels=out_ch,
                        # dims=dims,
                        # use_checkpoint=use_checkpoint,
                        # use_scale_shift_norm=use_scale_shift_norm,
                        # up=True,
                        # )
                        None
                        if resblock_updown
                        else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                    )
@@ -842,15 +849,15 @@ class EncoderUNetModel(nn.Module):
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                    ResnetBlock(
                        in_channels=ch,
                        out_channels=model_channels * mult,
                        dropout=dropout,
                        temb_channels=time_embed_dim,
                        eps=1e-5,
                        non_linearity="silu",
                        overwrite_for_ldm=True,
                    ),
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
@@ -870,16 +877,17 @@ class EncoderUNetModel(nn.Module):
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        # ResBlock(
                        # ch,
                        # time_embed_dim,
                        # dropout,
                        # out_channels=out_ch,
                        # dims=dims,
                        # use_checkpoint=use_checkpoint,
                        # use_scale_shift_norm=use_scale_shift_norm,
                        # down=True,
                        # )
                        None
                        if resblock_updown
                        else Downsample(
                            ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
@@ -892,13 +900,14 @@ class EncoderUNetModel(nn.Module):
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ResnetBlock(
                in_channels=ch,
                out_channels=None,
                dropout=dropout,
                temb_channels=time_embed_dim,
                eps=1e-5,
                non_linearity="silu",
                overwrite_for_ldm=True,
            ),
            AttentionBlock(
                ch,
@@ -907,13 +916,14 @@ class EncoderUNetModel(nn.Module):
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ResnetBlock(
                in_channels=ch,
                out_channels=None,
                dropout=dropout,
                temb_channels=time_embed_dim,
                eps=1e-5,
                non_linearity="silu",
                overwrite_for_ldm=True,
            ),
        )
        self._feature_size += ch