From 358531be9d858d41077c7e3ebe02a44df6261487 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 29 Jun 2022 17:30:35 +0000
Subject: [PATCH] up

---
 src/diffusers/models/resnet.py   | 185 +++++++++++++++++++++++++++++--
 src/diffusers/models/unet_ldm.py | 179 ++++++++++++++++--------------
 2 files changed, 271 insertions(+), 93 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 49c15642..8972e58e 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -162,7 +162,7 @@ class Downsample(nn.Module):
 # RESNETS


-# unet_glide.py & unet_ldm.py
+# unet_glide.py
 class ResBlock(TimestepBlock):
     """
     A residual block that can optionally change the number of channels.
@@ -188,6 +188,7 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
+        overwrite=False,  # TODO(Patrick) - use for glide at later stage
     ):
         super().__init__()
         self.channels = channels
@@ -236,6 +237,65 @@ class ResBlock(TimestepBlock):
         else:
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

+        self.overwrite = overwrite
+        self.is_overwritten = False
+        if self.overwrite:
+            in_channels = channels
+            out_channels = self.out_channels
+            conv_shortcut = False
+            dropout = 0.0
+            temb_channels = emb_channels
+            groups = 32
+            pre_norm = True
+            eps = 1e-5
+            non_linearity = "silu"
+            self.pre_norm = pre_norm
+            self.in_channels = in_channels
+            out_channels = in_channels if out_channels is None else out_channels
+            self.out_channels = out_channels
+            self.use_conv_shortcut = conv_shortcut
+
+            if self.pre_norm:
+                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
+            else:
+                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
+
+            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+            self.dropout = torch.nn.Dropout(dropout)
+            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            if non_linearity == "swish":
+                self.nonlinearity = nonlinearity
+            elif non_linearity == "mish":
+                self.nonlinearity = Mish()
+            elif non_linearity == "silu":
+                self.nonlinearity = nn.SiLU()
+
+            if self.in_channels != self.out_channels:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def set_weights(self):
+        # TODO(Patrick): use for glide at later stage
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
     def forward(self, x, emb):
         """
         Apply the block to a Tensor, conditioned on a timestep embedding.
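The new `set_weights()` above copies parameters from the fused GLIDE-style layout (`in_layers`/`emb_layers`/`out_layers`, each an `nn.Sequential`) into the split layout (`norm1`, `conv1`, `temb_proj`, `norm2`, `conv2`, `nin_shortcut`). The sketch below is illustrative only: the layer sizes and variable names are made up and are not the diffusers API; it just shows that copying the `.data` tensors one-for-one makes the two layouts compute the same thing.

```python
import torch
from torch import nn

# Fused, GLIDE-style layout: norm + activation + conv live inside one nn.Sequential.
in_layers = nn.Sequential(nn.GroupNorm(32, 64), nn.SiLU(), nn.Conv2d(64, 64, 3, padding=1))

# Split, ResnetBlock-style layout: the same ops held as separate modules.
norm1, act, conv1 = nn.GroupNorm(32, 64), nn.SiLU(), nn.Conv2d(64, 64, 3, padding=1)

# Port the parameters the same way set_weights() does, then check the two paths agree.
norm1.weight.data = in_layers[0].weight.data
norm1.bias.data = in_layers[0].bias.data
conv1.weight.data = in_layers[-1].weight.data
conv1.bias.data = in_layers[-1].bias.data

x = torch.randn(2, 64, 8, 8)
assert torch.allclose(in_layers(x), conv1(act(norm1(x))))
```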
@@ -243,6 +303,10 @@ class ResBlock(TimestepBlock):
         :param x: an [N x C x ...] Tensor of features.
         :param emb: an [N x emb_channels] Tensor of timestep embeddings.
         :return: an [N x C x ...] Tensor of outputs.
         """
+        if self.overwrite:
+            # TODO(Patrick): use for glide at later stage
+            self.set_weights()
+
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -251,6 +315,7 @@ class ResBlock(TimestepBlock):
             h = in_conv(h)
         else:
             h = self.in_layers(x)
+
         emb_out = self.emb_layers(emb).type(h.dtype)
         while len(emb_out.shape) < len(h.shape):
             emb_out = emb_out[..., None]
@@ -262,7 +327,50 @@ class ResBlock(TimestepBlock):
         else:
             h = h + emb_out
             h = self.out_layers(h)
-        return self.skip_connection(x) + h
+
+        result = self.skip_connection(x) + h
+
+# TODO(Patrick) Use for glide at later stage
+# result = self.forward_2(x, emb)
+
+        return result
+
+    def forward_2(self, x, temb, mask=1.0):
+        if self.overwrite and not self.is_overwritten:
+            self.set_weights()
+            self.is_overwritten = True
+
+        h = x
+        if self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = self.conv1(h)
+
+        if not self.pre_norm:
+            h = self.norm1(h)
+            h = self.nonlinearity(h)
+
+        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
+
+        if self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if not self.pre_norm:
+            h = self.norm2(h)
+            h = self.nonlinearity(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h


 # unet.py and unet_grad_tts.py
@@ -280,6 +388,7 @@ class ResnetBlock(nn.Module):
         eps=1e-6,
         non_linearity="swish",
         overwrite_for_grad_tts=False,
+        overwrite_for_ldm=False,
     ):
         super().__init__()
         self.pre_norm = pre_norm
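`forward_2` is the familiar pre-norm residual recipe: norm, nonlinearity, conv, add the projected timestep embedding, norm, nonlinearity, dropout, conv, then add the (possibly 1x1-projected) skip. A stripped-down, self-contained version of that data flow, with hypothetical channel counts and without the mask and post-norm branches, looks like this:

```python
import torch
from torch import nn

class TinyPreNormResBlock(nn.Module):
    # Illustrative only: norm -> SiLU -> conv -> +temb -> norm -> SiLU -> conv -> +skip.
    def __init__(self, channels=64, temb_channels=128):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.temb_proj = nn.Linear(temb_channels, channels)
        self.norm2 = nn.GroupNorm(32, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.act = nn.SiLU()

    def forward(self, x, temb):
        h = self.conv1(self.act(self.norm1(x)))
        h = h + self.temb_proj(self.act(temb))[:, :, None, None]  # broadcast over H, W
        h = self.conv2(self.act(self.norm2(h)))
        return x + h  # identity skip, since in/out channels match here

out = TinyPreNormResBlock()(torch.randn(2, 64, 8, 8), torch.randn(2, 128))
```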
@@ -302,15 +411,19 @@ class ResnetBlock(nn.Module):
             self.nonlinearity = nonlinearity
         elif non_linearity == "mish":
             self.nonlinearity = Mish()
+        elif non_linearity == "silu":
+            self.nonlinearity = nn.SiLU()

         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
+                # TODO(Patrick) - this branch is never used I think => can be deleted!
                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
             else:
                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

         self.is_overwritten = False
         self.overwrite_for_grad_tts = overwrite_for_grad_tts
+        self.overwrite_for_ldm = overwrite_for_ldm
         if self.overwrite_for_grad_tts:
             dim = in_channels
             dim_out = out_channels
@@ -324,6 +437,39 @@ class ResnetBlock(nn.Module):
                 self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
             else:
                 self.res_conv = torch.nn.Identity()
+        elif self.overwrite_for_ldm:
+            dims = 2
+# eps = 1e-5
+# non_linearity = "silu"
+# overwrite_for_ldm
+            channels = in_channels
+            emb_channels = temb_channels
+            use_scale_shift_norm = False
+
+            self.in_layers = nn.Sequential(
+                normalization(channels, swish=1.0),
+                nn.Identity(),
+                conv_nd(dims, channels, self.out_channels, 3, padding=1),
+            )
+            self.emb_layers = nn.Sequential(
+                nn.SiLU(),
+                linear(
+                    emb_channels,
+                    2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                ),
+            )
+            self.out_layers = nn.Sequential(
+                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
+                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+                nn.Dropout(p=dropout),
+                zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
+            )
+            if self.out_channels == in_channels:
+                self.skip_connection = nn.Identity()
+# elif use_conv:
+#     self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
+            else:
+                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
@@ -343,13 +489,36 @@ class ResnetBlock(nn.Module):
             self.nin_shortcut.weight.data = self.res_conv.weight.data
             self.nin_shortcut.bias.data = self.res_conv.bias.data

-    def forward(self, x, temb, mask=None):
+    def set_weights_ldm(self):
+        self.norm1.weight.data = self.in_layers[0].weight.data
+        self.norm1.bias.data = self.in_layers[0].bias.data
+
+        self.conv1.weight.data = self.in_layers[-1].weight.data
+        self.conv1.bias.data = self.in_layers[-1].bias.data
+
+        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
+        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
+
+        self.norm2.weight.data = self.out_layers[0].weight.data
+        self.norm2.bias.data = self.out_layers[0].bias.data
+
+        self.conv2.weight.data = self.out_layers[-1].weight.data
+        self.conv2.bias.data = self.out_layers[-1].bias.data
+
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut.weight.data = self.skip_connection.weight.data
+            self.nin_shortcut.bias.data = self.skip_connection.bias.data
+
+    def forward(self, x, temb, mask=1.0):
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
+        elif self.overwrite_for_ldm and not self.is_overwritten:
+            self.set_weights_ldm()
+            self.is_overwritten = True

         h = x
-        h = h * mask if mask is not None else h
+        h = h * mask
         if self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
@@ -359,11 +528,11 @@ class ResnetBlock(nn.Module):
         if not self.pre_norm:
             h = self.norm1(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask

         h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]

-        h = h * mask if mask is not None else h
+        h = h * mask
         if self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
@@ -374,9 +543,9 @@ class ResnetBlock(nn.Module):
         if not self.pre_norm:
             h = self.norm2(h)
             h = self.nonlinearity(h)
-        h = h * mask if mask is not None else h
+        h = h * mask

-        x = x * mask if mask is not None else x
+        x = x * mask
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
                 x = self.conv_shortcut(x)
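The `overwrite_for_ldm` path builds the checkpoint-compatible `in_layers`/`emb_layers`/`out_layers` modules in `__init__`, and then on the first `forward` call `set_weights_ldm()` copies those tensors into `norm1`/`conv1`/`temb_proj`/`norm2`/`conv2`/`nin_shortcut`, which is what the new forward path actually uses. A minimal sketch of this lazy copy-once-on-first-use pattern; the class and attribute names below are illustrative, not the diffusers API:

```python
import torch
from torch import nn

class LazyPortedLinear(nn.Module):
    # Keeps parameters under their legacy name so an old state_dict still loads,
    # and copies them into the new module the first time forward() runs.
    def __init__(self, features=16):
        super().__init__()
        self.legacy = nn.Linear(features, features)   # layout an old checkpoint populates
        self.linear = nn.Linear(features, features)   # layout the new forward path uses
        self.is_overwritten = False

    def forward(self, x):
        if not self.is_overwritten:
            self.linear.weight.data = self.legacy.weight.data
            self.linear.bias.data = self.legacy.bias.data
            self.is_overwritten = True
        return self.linear(x)

m = LazyPortedLinear()
assert torch.allclose(m(torch.ones(1, 16)), m.legacy(torch.ones(1, 16)))
```

Note also that `forward` now defaults `mask` to `1.0` instead of `None`: `h = h * mask` becomes a broadcasted multiply by a scalar for callers that pass nothing, so the `if mask is not None` branches can be dropped.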
diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py
index 0571013d..f78f3afd 100644
--- a/src/diffusers/models/unet_ldm.py
+++ b/src/diffusers/models/unet_ldm.py
@@ -10,7 +10,9 @@ from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
-from .resnet import Downsample, ResBlock, TimestepBlock, Upsample
+from .resnet import Downsample, TimestepBlock, Upsample
+from .resnet import ResnetBlock
+#from .resnet import ResBlock


 def exists(val):
@@ -364,7 +366,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, context=None):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, TimestepBlock) or isinstance(layer, ResnetBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
                 x = layer(x, context)
@@ -559,14 +561,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=mult * model_channels,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
                     )
                 ]
                 ch = mult * model_channels
@@ -599,16 +601,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
+# ResBlock(
+#     ch,
+#     time_embed_dim,
+#     dropout,
+#     out_channels=out_ch,
+#     dims=dims,
+#     use_checkpoint=use_checkpoint,
+#     use_scale_shift_norm=use_scale_shift_norm,
+#     down=True,
+# )
+                        None
                         if resblock_updown
                        else Downsample(
                             ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
@@ -629,13 +632,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
            # num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
             AttentionBlock(
                 ch,
@@ -646,13 +650,14 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             )
             if not use_spatial_transformer
             else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
         )
         self._feature_size += ch
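Because `TimestepEmbedSequential` decides per child whether to pass the timestep embedding, it has to recognize the new `ResnetBlock` as well, hence the widened `isinstance` check above. The container pattern itself is small; a self-contained sketch, with illustrative class names rather than the diffusers API:

```python
import torch
from torch import nn

class TimestepAware(nn.Module):
    """Marker base class: children of this type receive the embedding."""

class EmbAwareSequential(nn.Sequential):
    # Passes (x, emb) to timestep-aware children, plain x to everything else.
    def forward(self, x, emb):
        for layer in self:
            x = layer(x, emb) if isinstance(layer, TimestepAware) else layer(x)
        return x

class AddEmb(TimestepAware):
    def __init__(self, channels=8, emb_dim=4):
        super().__init__()
        self.proj = nn.Linear(emb_dim, channels)

    def forward(self, x, emb):
        return x + self.proj(emb)[:, :, None, None]

block = EmbAwareSequential(nn.Conv2d(8, 8, 1), AddEmb())
y = block(torch.randn(2, 8, 4, 4), torch.randn(2, 4))
```

Widening the check in the container keeps `ResnetBlock`'s signature untouched while still letting it sit next to attention layers in the same sequential block.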
@@ -662,15 +667,15 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
             for i in range(num_res_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
-                    ResBlock(
-                        ch + ich,
-                        time_embed_dim,
-                        dropout,
+                    ResnetBlock(
+                        in_channels=ch + ich,
                         out_channels=model_channels * mult,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
                 ]
                 ch = model_channels * mult
                 if ds in attention_resolutions:
@@ -698,16 +703,17 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 if level and i == num_res_blocks:
                     out_ch = ch
                     layers.append(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            up=True,
-                        )
+# ResBlock(
+#     ch,
+#     time_embed_dim,
+#     dropout,
+#     out_channels=out_ch,
+#     dims=dims,
+#     use_checkpoint=use_checkpoint,
+#     use_scale_shift_norm=use_scale_shift_norm,
+#     up=True,
+# )
+                        None
                         if resblock_updown
                         else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
                     )
@@ -842,15 +848,15 @@ class EncoderUNetModel(nn.Module):
         for level, mult in enumerate(channel_mult):
             for _ in range(num_res_blocks):
                 layers = [
-                    ResBlock(
-                        ch,
-                        time_embed_dim,
-                        dropout,
-                        out_channels=mult * model_channels,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                    )
+                    ResnetBlock(
+                        in_channels=ch,
+                        out_channels=model_channels * mult,
+                        dropout=dropout,
+                        temb_channels=time_embed_dim,
+                        eps=1e-5,
+                        non_linearity="silu",
+                        overwrite_for_ldm=True,
+                    ),
                 ]
                 ch = mult * model_channels
                 if ds in attention_resolutions:
@@ -870,16 +876,17 @@ class EncoderUNetModel(nn.Module):
                 out_ch = ch
                 self.input_blocks.append(
                     TimestepEmbedSequential(
-                        ResBlock(
-                            ch,
-                            time_embed_dim,
-                            dropout,
-                            out_channels=out_ch,
-                            dims=dims,
-                            use_checkpoint=use_checkpoint,
-                            use_scale_shift_norm=use_scale_shift_norm,
-                            down=True,
-                        )
+# ResBlock(
+#     ch,
+#     time_embed_dim,
+#     dropout,
+#     out_channels=out_ch,
+#     dims=dims,
+#     use_checkpoint=use_checkpoint,
+#     use_scale_shift_norm=use_scale_shift_norm,
+#     down=True,
+# )
+                        None
                         if resblock_updown
                         else Downsample(
                             ch, use_conv=conv_resample, dims=dims, out_channels=out_ch, padding=1, name="op"
@@ -892,13 +899,14 @@ class EncoderUNetModel(nn.Module):
         self._feature_size += ch

         self.middle_block = TimestepEmbedSequential(
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
             AttentionBlock(
                 ch,
@@ -907,13 +915,14 @@ class EncoderUNetModel(nn.Module):
                 num_head_channels=num_head_channels,
                 use_new_attention_order=use_new_attention_order,
             ),
-            ResBlock(
-                ch,
-                time_embed_dim,
-                dropout,
-                dims=dims,
-                use_checkpoint=use_checkpoint,
-                use_scale_shift_norm=use_scale_shift_norm,
+            ResnetBlock(
+                in_channels=ch,
+                out_channels=None,
+                dropout=dropout,
+                temb_channels=time_embed_dim,
+                eps=1e-5,
+                non_linearity="silu",
+                overwrite_for_ldm=True,
             ),
         )
         self._feature_size += ch
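The ported `out_layers` above still wrap their final convolution in `zero_module(...)`: zero-initializing the last conv makes a freshly constructed residual block contribute nothing, so the skip path passes through unchanged until training (or, here, `set_weights_ldm`) fills the weights in. A self-contained illustration of the trick; the `zero_module` helper is re-defined locally and only mirrors what the diff relies on:

```python
import torch
from torch import nn

def zero_module(module):
    # Zero out all parameters so the wrapped layer initially contributes nothing.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

skip = nn.Identity()
out_conv = zero_module(nn.Conv2d(32, 32, 3, padding=1))

x = torch.randn(2, 32, 8, 8)
h = out_conv(torch.relu(x))          # stands in for the block's out_layers path
assert torch.equal(skip(x) + h, x)   # the residual branch is exactly zero at init
```

Once `set_weights_ldm()` has copied a checkpoint's tensors over these zero-initialized modules, the ResnetBlock-based UNet should reproduce the old ResBlock outputs, which appears to be the point of the overwrite machinery.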