Merge pull request #46 from huggingface/merge_ldm_resnet

[ResNet Refactor] Merge ldm into resnet
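This replaces the GLIDE/LDM-style `ResBlock` usages in `unet_ldm.py` with the unified `ResnetBlock` from `resnet.py`. `ResnetBlock` gains an `overwrite_for_ldm` flag: when set, it additionally builds the original `in_layers` / `emb_layers` / `out_layers` submodules and, on the first forward pass, copies their weights into the refactored `norm1` / `conv1` / `temb_proj` / `norm2` / `conv2` (and `nin_shortcut`) layers, so the merged block keeps reproducing LDM checkpoint behaviour.

A minimal sketch of the new call pattern, assuming the module path `diffusers.models.resnet` and using illustrative tensor sizes (not taken from the PR):

```python
import torch
from diffusers.models.resnet import ResnetBlock  # import path assumed for this sketch

block = ResnetBlock(
    in_channels=64,
    out_channels=128,
    dropout=0.0,
    temb_channels=512,
    eps=1e-5,
    non_linearity="silu",
    overwrite_for_ldm=True,  # build LDM-compatible submodules; weights copied on first use
)

x = torch.randn(2, 64, 32, 32)  # [N x C x H x W] feature map
temb = torch.randn(2, 512)      # [N x temb_channels] timestep embedding
out = block(x, temb)            # first call triggers set_weights_ldm()
print(out.shape)                # expected: torch.Size([2, 128, 32, 32])
```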
This commit is contained in:
Patrick von Platen 2022-06-29 19:34:13 +02:00 committed by GitHub
commit b65eb377dd
2 changed files with 271 additions and 92 deletions

resnet.py

@@ -162,7 +162,7 @@ class Downsample(nn.Module):
# RESNETS
# unet_glide.py & unet_ldm.py
# unet_glide.py
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
@@ -188,6 +188,7 @@ class ResBlock(TimestepBlock):
use_checkpoint=False,
up=False,
down=False,
overwrite=False, # TODO(Patrick) - use for glide at later stage
):
super().__init__()
self.channels = channels
@@ -236,6 +237,65 @@ class ResBlock(TimestepBlock):
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
self.overwrite = overwrite
self.is_overwritten = False
if self.overwrite:
in_channels = channels
out_channels = self.out_channels
conv_shortcut = False
dropout = 0.0
temb_channels = emb_channels
groups = 32
pre_norm = True
eps = 1e-5
non_linearity = "silu"
self.pre_norm = pre_norm
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def set_weights(self):
# TODO(Patrick): use for glide at later stage
self.norm1.weight.data = self.in_layers[0].weight.data
self.norm1.bias.data = self.in_layers[0].bias.data
self.conv1.weight.data = self.in_layers[-1].weight.data
self.conv1.bias.data = self.in_layers[-1].bias.data
self.temb_proj.weight.data = self.emb_layers[-1].weight.data
self.temb_proj.bias.data = self.emb_layers[-1].bias.data
self.norm2.weight.data = self.out_layers[0].weight.data
self.norm2.bias.data = self.out_layers[0].bias.data
self.conv2.weight.data = self.out_layers[-1].weight.data
self.conv2.bias.data = self.out_layers[-1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.skip_connection.weight.data
self.nin_shortcut.bias.data = self.skip_connection.bias.data
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
@@ -243,6 +303,10 @@ class ResBlock(TimestepBlock):
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
if self.overwrite:
# TODO(Patrick): use for glide at later stage
self.set_weights()
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
@@ -251,6 +315,7 @@ class ResBlock(TimestepBlock):
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
@@ -262,7 +327,50 @@ class ResBlock(TimestepBlock):
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
result = self.skip_connection(x) + h
# TODO(Patrick) Use for glide at later stage
# result = self.forward_2(x, emb)
return result
def forward_2(self, x, temb, mask=1.0):
if self.overwrite and not self.is_overwritten:
self.set_weights()
self.is_overwritten = True
h = x
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = self.conv1(h)
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
# unet.py and unet_grad_tts.py
@@ -280,6 +388,7 @@ class ResnetBlock(nn.Module):
eps=1e-6,
non_linearity="swish",
overwrite_for_grad_tts=False,
overwrite_for_ldm=False,
):
super().__init__()
self.pre_norm = pre_norm
@@ -302,15 +411,19 @@ class ResnetBlock(nn.Module):
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
# TODO(Patrick) - this branch is never used I think => can be deleted!
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.is_overwritten = False
self.overwrite_for_grad_tts = overwrite_for_grad_tts
self.overwrite_for_ldm = overwrite_for_ldm
if self.overwrite_for_grad_tts:
dim = in_channels
dim_out = out_channels
@@ -324,6 +437,39 @@ class ResnetBlock(nn.Module):
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
elif self.overwrite_for_ldm:
dims = 2
# eps = 1e-5
# non_linearity = "silu"
# overwrite_for_ldm
channels = in_channels
emb_channels = temb_channels
use_scale_shift_norm = False
self.in_layers = nn.Sequential(
normalization(channels, swish=1.0),
nn.Identity(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
nn.SiLU() if use_scale_shift_norm else nn.Identity(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == in_channels:
self.skip_connection = nn.Identity()
# elif use_conv:
# self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
@@ -343,13 +489,36 @@ class ResnetBlock(nn.Module):
self.nin_shortcut.weight.data = self.res_conv.weight.data
self.nin_shortcut.bias.data = self.res_conv.bias.data
def forward(self, x, temb, mask=None):
def set_weights_ldm(self):
self.norm1.weight.data = self.in_layers[0].weight.data
self.norm1.bias.data = self.in_layers[0].bias.data
self.conv1.weight.data = self.in_layers[-1].weight.data
self.conv1.bias.data = self.in_layers[-1].bias.data
self.temb_proj.weight.data = self.emb_layers[-1].weight.data
self.temb_proj.bias.data = self.emb_layers[-1].bias.data
self.norm2.weight.data = self.out_layers[0].weight.data
self.norm2.bias.data = self.out_layers[0].bias.data
self.conv2.weight.data = self.out_layers[-1].weight.data
self.conv2.bias.data = self.out_layers[-1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.skip_connection.weight.data
self.nin_shortcut.bias.data = self.skip_connection.bias.data
def forward(self, x, temb, mask=1.0):
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
elif self.overwrite_for_ldm and not self.is_overwritten:
self.set_weights_ldm()
self.is_overwritten = True
h = x
h = h * mask if mask is not None else h
h = h * mask
if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
@@ -359,11 +528,11 @@ class ResnetBlock(nn.Module):
if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask if mask is not None else h
h = h * mask
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h * mask if mask is not None else h
h = h * mask
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
@@ -374,9 +543,9 @@ class ResnetBlock(nn.Module):
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask if mask is not None else h
h = h * mask
x = x * mask if mask is not None else x
x = x * mask
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
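
For reference, `set_weights_ldm()` (and the analogous `set_weights()` kept on `ResBlock` behind a TODO) maps the GLIDE-style `nn.Sequential` containers onto the refactored layers; the fused-SiLU GroupNorm built by `normalization(..., swish=1.0)` appears to transfer directly because the refactored block applies the nonlinearity separately via `self.nonlinearity`. A hedged sketch of that mapping, with a hypothetical helper name and assuming the attribute names shown in the diff:

```python
def copy_ldm_weights(block) -> None:
    """Hypothetical helper mirroring set_weights_ldm(): copy the GLIDE-style
    Sequential weights of a ResnetBlock into its refactored layers."""
    pairs = [
        (block.norm1, block.in_layers[0]),        # GroupNorm before the first conv
        (block.conv1, block.in_layers[-1]),       # 3x3 conv
        (block.temb_proj, block.emb_layers[-1]),  # timestep-embedding projection
        (block.norm2, block.out_layers[0]),       # GroupNorm before the second conv
        (block.conv2, block.out_layers[-1]),      # zero-initialised 3x3 conv
    ]
    if block.in_channels != block.out_channels:
        pairs.append((block.nin_shortcut, block.skip_connection))  # 1x1 skip conv
    for dst, src in pairs:
        dst.weight.data.copy_(src.weight.data)
        dst.bias.data.copy_(src.bias.data)
```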

unet_ldm.py

@@ -10,7 +10,10 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import get_timestep_embedding
from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample
# from .resnet import ResBlock
def exists(val):
@@ -364,7 +367,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
def forward(self, x, emb, context=None):
for layer in self:
if isinstance(layer, TimestepBlock):
if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
@@ -559,14 +562,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
ResnetBlock(
in_channels=ch,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
)
]
ch = mult * model_channels
@@ -599,16 +602,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# down=True,
# )
None
if resblock_updown
else Downsample(
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
@@ -629,13 +633,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
ResnetBlock(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
AttentionBlock(
ch,
@@ -646,13 +651,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
)
if not use_spatial_transformer
else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
ResnetBlock(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
)
self._feature_size += ch
@@ -662,15 +668,15 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
ResnetBlock(
in_channels=ch + ich,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
]
ch = model_channels * mult
if ds in attention_resolutions:
@@ -698,16 +704,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
if level and i == num_res_blocks:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# up=True,
# )
None
if resblock_updown
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
)
@@ -842,15 +849,15 @@ class EncoderUNetModel(nn.Module):
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
ResnetBlock(
in_channels=ch,
out_channels=model_channels * mult,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
]
ch = mult * model_channels
if ds in attention_resolutions:
@@ -870,16 +877,17 @@ class EncoderUNetModel(nn.Module):
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
# ResBlock(
# ch,
# time_embed_dim,
# dropout,
# out_channels=out_ch,
# dims=dims,
# use_checkpoint=use_checkpoint,
# use_scale_shift_norm=use_scale_shift_norm,
# down=True,
# )
None
if resblock_updown
else Downsample(
ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
@@ -892,13 +900,14 @@ class EncoderUNetModel(nn.Module):
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
ResnetBlock(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
AttentionBlock(
ch,
@@ -907,13 +916,14 @@ class EncoderUNetModel(nn.Module):
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
ResnetBlock(
in_channels=ch,
out_channels=None,
dropout=dropout,
temb_channels=time_embed_dim,
eps=1e-5,
non_linearity="silu",
overwrite_for_ldm=True,
),
)
self._feature_size += ch
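
In `unet_ldm.py`, the positional `ResBlock(ch, time_embed_dim, dropout, ...)` calls in `UNetLDMModel` and `EncoderUNetModel` become keyword-based `ResnetBlock(...)` calls with `overwrite_for_ldm=True`, and `TimestepEmbedSequential.forward` now also routes `ResnetBlock` children through `layer(x, emb)`, since `ResnetBlock` does not inherit from `TimestepBlock`. Note that the up/downsampling `ResBlock`s behind `resblock_updown` are only commented out and replaced by `None` for now, so that configuration would need a follow-up before it can be used again. A simplified sketch of the new dispatch (SpatialTransformer handling omitted, import path assumed):

```python
import torch.nn as nn
from diffusers.models.resnet import ResnetBlock, TimestepBlock  # import path assumed

class TimestepEmbedSequentialSketch(nn.Sequential):
    """Simplified version of the dispatch in TimestepEmbedSequential.forward."""

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, (TimestepBlock, ResnetBlock)):
                # timestep-conditioned blocks receive the embedding as well
                x = layer(x, emb)
            else:
                x = layer(x)
        return x
```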