From efe1e60e12d07ef8a32db7e43935e6bd9ea74904 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Jun 2022 22:24:22 +0000 Subject: [PATCH] merge glide into resnets --- src/diffusers/models/resnet.py | 243 +---------------------------- src/diffusers/models/unet_glide.py | 73 ++------- tests/test_modeling_utils.py | 2 +- 3 files changed, 16 insertions(+), 302 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index f95d9198..f48a9403 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -161,229 +161,7 @@ class Downsample(nn.Module): # RESNETS -# unet_glide.py -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - - :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param - use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing - on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for - downsampling. - """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - overwrite=True, # TODO(Patrick) - use for glide at later stage - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - normalization(channels, swish=1.0), - nn.Identity(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, use_conv=False, dims=dims) - self.x_upd = Upsample(channels, use_conv=False, dims=dims) - elif down: - self.h_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") - self.x_upd = Downsample(channels, use_conv=False, dims=dims, padding=1, name="op") - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels, swish=0.0), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - self.overwrite = overwrite - self.is_overwritten = False - if self.overwrite: - in_channels = channels - out_channels = self.out_channels - conv_shortcut = False - dropout = 0.0 - temb_channels = emb_channels - groups = 32 - pre_norm = True - eps = 1e-5 - non_linearity = "silu" - self.pre_norm = pre_norm - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - # Add to init - self.time_embedding_norm = "scale_shift" - - if self.pre_norm: - self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps) - else: - self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps) - - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) - self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if non_linearity == "swish": - self.nonlinearity = nonlinearity - elif non_linearity == "mish": - self.nonlinearity = Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - - if self.in_channels != self.out_channels: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - self.up, self.down = up, down -# if self.up: -# self.h_upd = Upsample(in_channels, use_conv=False, dims=dims) -# self.x_upd = Upsample(in_channels, use_conv=False, dims=dims) -# elif self.down: -# self.h_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op") -# self.x_upd = Downsample(in_channels, use_conv=False, dims=dims, padding=1, name="op") - - def set_weights(self): - # TODO(Patrick): use for glide at later stage - self.norm1.weight.data = self.in_layers[0].weight.data - self.norm1.bias.data = self.in_layers[0].bias.data - - self.conv1.weight.data = self.in_layers[-1].weight.data - self.conv1.bias.data = self.in_layers[-1].bias.data - - self.temb_proj.weight.data = self.emb_layers[-1].weight.data - self.temb_proj.bias.data = self.emb_layers[-1].bias.data - - self.norm2.weight.data = self.out_layers[0].weight.data - self.norm2.bias.data = self.out_layers[0].bias.data - - self.conv2.weight.data = self.out_layers[-1].weight.data - self.conv2.bias.data = self.out_layers[-1].bias.data - - if self.in_channels != self.out_channels: - self.nin_shortcut.weight.data = self.skip_connection.weight.data - self.nin_shortcut.bias.data = self.skip_connection.bias.data - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - - :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - if self.overwrite: - # TODO(Patrick): use for glide at later stage - self.set_weights() - - orig_x = x - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = torch.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - - result = self.skip_connection(x) + h - - # TODO(Patrick) Use for glide at later stage - result = self.forward_2(orig_x, emb) - return result - - def forward_2(self, x, temb): - if self.overwrite and not self.is_overwritten: - self.set_weights() - self.is_overwritten = True - - h = x - h = self.norm1(h) - h = self.nonlinearity(h) - - if self.up or self.down: - x = self.x_upd(x) - h = self.h_upd(h) - - h = self.conv1(h) - - temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None] - - if self.time_embedding_norm == "scale_shift": - scale, shift = torch.chunk(temb, 2, dim=1) - - h = self.norm2(h) - h = h + h * scale + shift - h = self.nonlinearity(h) - else: - h = h + temb - h = self.norm2(h) - h = self.nonlinearity(h) - - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - x = self.nin_shortcut(x) - - return x + h - - -# unet.py, unet_grad_tts.py, unet_ldm.py +# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py class ResnetBlock(nn.Module): def __init__( self, @@ -445,12 +223,9 @@ class ResnetBlock(nn.Module): self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - # TODO(Patrick) - this branch is never used I think => can be deleted! - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED self.is_overwritten = False self.overwrite_for_glide = overwrite_for_glide self.overwrite_for_grad_tts = overwrite_for_grad_tts @@ -497,8 +272,6 @@ class ResnetBlock(nn.Module): ) if self.out_channels == in_channels: self.skip_connection = nn.Identity() - # elif use_conv: - # self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) @@ -541,6 +314,8 @@ class ResnetBlock(nn.Module): self.nin_shortcut.bias.data = self.skip_connection.bias.data def forward(self, x, temb, mask=1.0): + # TODO(Patrick) eventually this class should be split into multiple classes + # too many if else statements if self.overwrite_for_grad_tts and not self.is_overwritten: self.set_weights_grad_tts() self.is_overwritten = True @@ -566,6 +341,7 @@ class ResnetBlock(nn.Module): h = h * mask temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None] + if self.time_embedding_norm == "scale_shift": scale, shift = torch.chunk(temb, 2, dim=1) @@ -589,9 +365,6 @@ class ResnetBlock(nn.Module): x = x * mask if self.in_channels != self.out_channels: -# if self.use_conv_shortcut: -# x = self.conv_shortcut(x) -# else: x = self.nin_shortcut(x) return x + h @@ -605,10 +378,6 @@ class Block(torch.nn.Module): torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish() ) - def forward(self, x, mask): - output = self.block(x * mask) - return output * mask - # unet_score_estimation.py class ResnetBlockBigGANpp(nn.Module): diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index a0af4b9f..a7450797 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -6,8 +6,7 @@ from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .attention import AttentionBlock from .embeddings import get_timestep_embedding -from .resnet import Downsample, ResBlock, TimestepBlock, Upsample -from .resnet import ResnetBlock +from .resnet import Downsample, ResnetBlock, TimestepBlock, Upsample def convert_module_to_f16(l): @@ -191,15 +190,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin): for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ -# ResBlock( -# ch, -# time_embed_dim, -# dropout, -# out_channels=int(mult * model_channels), -# dims=dims, -# use_checkpoint=use_checkpoint, -# use_scale_shift_norm=use_scale_shift_norm, -# ) ResnetBlock( in_channels=ch, out_channels=mult * model_channels, @@ -207,7 +197,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", - time_embedding_norm="scale_shift", + time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", overwrite_for_glide=True, ) ] @@ -229,16 +219,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin): out_ch = ch self.input_blocks.append( TimestepEmbedSequential( -# ResBlock( -# ch, -# time_embed_dim, -# dropout, -# out_channels=out_ch, -# dims=dims, -# use_checkpoint=use_checkpoint, -# use_scale_shift_norm=use_scale_shift_norm, -# down=True, -# ) ResnetBlock( in_channels=ch, out_channels=out_ch, @@ -246,9 +226,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin): temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", - time_embedding_norm="scale_shift", + time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", overwrite_for_glide=True, - down=True + down=True, ) if resblock_updown else Downsample( @@ -262,21 +242,13 @@ class GlideUNetModel(ModelMixin, ConfigMixin): self._feature_size += ch self.middle_block = TimestepEmbedSequential( -# ResBlock( -# ch, -# time_embed_dim, -# dropout, -# dims=dims, -# use_checkpoint=use_checkpoint, -# use_scale_shift_norm=use_scale_shift_norm, -# ), ResnetBlock( in_channels=ch, dropout=dropout, temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", - time_embedding_norm="scale_shift", + time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", overwrite_for_glide=True, ), AttentionBlock( @@ -286,23 +258,15 @@ class GlideUNetModel(ModelMixin, ConfigMixin): num_head_channels=num_head_channels, encoder_channels=transformer_dim, ), -# ResBlock( -# ch, -# time_embed_dim, -# dropout, -# dims=dims, -# use_checkpoint=use_checkpoint, -# use_scale_shift_norm=use_scale_shift_norm, -# ), ResnetBlock( in_channels=ch, dropout=dropout, temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", - time_embedding_norm="scale_shift", + time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", overwrite_for_glide=True, - ) + ), ) self._feature_size += ch @@ -311,15 +275,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin): for i in range(num_res_blocks + 1): ich = input_block_chans.pop() layers = [ -# ResBlock( -# ch + ich, -# time_embed_dim, -# dropout, -# out_channels=int(model_channels * mult), -# dims=dims, -# use_checkpoint=use_checkpoint, -# use_scale_shift_norm=use_scale_shift_norm, -# ) ResnetBlock( in_channels=ch + ich, out_channels=model_channels * mult, @@ -327,7 +282,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", - time_embedding_norm="scale_shift", + time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", overwrite_for_glide=True, ), ] @@ -345,16 +300,6 @@ class GlideUNetModel(ModelMixin, ConfigMixin): if level and i == num_res_blocks: out_ch = ch layers.append( -# ResBlock( -# ch, -# time_embed_dim, -# dropout, -# out_channels=out_ch, -# dims=dims, -# use_checkpoint=use_checkpoint, -# use_scale_shift_norm=use_scale_shift_norm, -# up=True, -# ) ResnetBlock( in_channels=ch, out_channels=out_ch, @@ -362,7 +307,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin): temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", - time_embedding_norm="scale_shift", + time_embedding_norm="scale_shift" if use_scale_shift_norm else "default", overwrite_for_glide=True, up=True, ) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index ff37e8ab..1a410b93 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -795,7 +795,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): sizes = (32, 32) noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [9.]).to(torch_device) + time_step = torch.tensor(batch_size * [9.0]).to(torch_device) with torch.no_grad(): output = model(noise, time_step)