From 6028d58cb004dbfc0af61f17b9104fe037486d6d Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Tue, 23 Aug 2022 08:38:37 +0200
Subject: [PATCH] Remove dead code in `resnet.py` (#218)

remove dead code in resnet.py

Co-authored-by: ydshieh
---
 src/diffusers/models/resnet.py | 400 ---------------------------------
 1 file changed, 400 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 15cf6e26..c61aa270 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -364,412 +364,12 @@ class ResnetBlock(nn.Module):
 
         return out
 
-    def set_weight(self, resnet):
-        self.norm1.weight.data = resnet.norm1.weight.data
-        self.norm1.bias.data = resnet.norm1.bias.data
-
-        self.conv1.weight.data = resnet.conv1.weight.data
-        self.conv1.bias.data = resnet.conv1.bias.data
-
-        if self.time_emb_proj is not None:
-            self.time_emb_proj.weight.data = resnet.temb_proj.weight.data
-            self.time_emb_proj.bias.data = resnet.temb_proj.bias.data
-
-        self.norm2.weight.data = resnet.norm2.weight.data
-        self.norm2.bias.data = resnet.norm2.bias.data
-
-        self.conv2.weight.data = resnet.conv2.weight.data
-        self.conv2.bias.data = resnet.conv2.bias.data
-
-        if self.use_nin_shortcut:
-            self.conv_shortcut.weight.data = resnet.nin_shortcut.weight.data
-            self.conv_shortcut.bias.data = resnet.nin_shortcut.bias.data
-
-
-# THE FOLLOWING SHOULD BE DELETED ONCE ALL CHECKPOITNS ARE CONVERTED
-
-# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
-# => All 2D-Resnets are included here now!
-class ResnetBlock2D(nn.Module):
-    def __init__(
-        self,
-        *,
-        in_channels,
-        out_channels=None,
-        conv_shortcut=False,
-        dropout=0.0,
-        temb_channels=512,
-        groups=32,
-        groups_out=None,
-        pre_norm=True,
-        eps=1e-6,
-        non_linearity="swish",
-        time_embedding_norm="default",
-        kernel=None,
-        output_scale_factor=1.0,
-        use_nin_shortcut=None,
-        up=False,
-        down=False,
-        overwrite_for_grad_tts=False,
-        overwrite_for_ldm=False,
-        overwrite_for_glide=False,
-        overwrite_for_score_vde=False,
-    ):
-        super().__init__()
-        self.pre_norm = pre_norm
-        self.in_channels = in_channels
-        out_channels = in_channels if out_channels is None else out_channels
-        self.out_channels = out_channels
-        self.use_conv_shortcut = conv_shortcut
-        self.time_embedding_norm = time_embedding_norm
-        self.up = up
-        self.down = down
-        self.output_scale_factor = output_scale_factor
-
-        if groups_out is None:
-            groups_out = groups
-
-        if self.pre_norm:
-            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
-        else:
-            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
-
-        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-
-        if time_embedding_norm == "default" and temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-        elif time_embedding_norm == "scale_shift" and temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
-
-        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
-        self.dropout = torch.nn.Dropout(dropout)
-        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
-
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-
-        self.upsample = self.downsample = None
-        if self.up:
-            if kernel == "fir":
-                fir_kernel = (1, 3, 3, 1)
-                self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
-            elif kernel == "sde_vp":
-                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
-            else:
-                self.upsample = Upsample2D(in_channels, use_conv=False)
-        elif self.down:
-            if kernel == "fir":
-                fir_kernel = (1, 3, 3, 1)
-                self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
-            elif kernel == "sde_vp":
-                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
-            else:
-                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
-
-        self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
-
-        self.nin_shortcut = None
-        if self.use_nin_shortcut:
-            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
-        # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
-        self.is_overwritten = False
-        self.overwrite_for_glide = overwrite_for_glide
-        self.overwrite_for_grad_tts = overwrite_for_grad_tts
-        self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
-        self.overwrite_for_score_vde = overwrite_for_score_vde
-        if self.overwrite_for_grad_tts:
-            dim = in_channels
-            dim_out = out_channels
-            time_emb_dim = temb_channels
-            self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
-            self.pre_norm = pre_norm
-
-            self.block1 = Block(dim, dim_out, groups=groups)
-            self.block2 = Block(dim_out, dim_out, groups=groups)
-            if dim != dim_out:
-                self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
-            else:
-                self.res_conv = torch.nn.Identity()
-        elif self.overwrite_for_ldm:
-            channels = in_channels
-            emb_channels = temb_channels
-            use_scale_shift_norm = False
-            non_linearity = "silu"
-
-            self.in_layers = nn.Sequential(
-                normalization(channels, swish=1.0),
-                nn.Identity(),
-                nn.Conv2d(channels, self.out_channels, 3, padding=1),
-            )
-            self.emb_layers = nn.Sequential(
-                nn.SiLU(),
-                linear(
-                    emb_channels,
-                    2 * self.out_channels if self.time_embedding_norm == "scale_shift" else self.out_channels,
-                ),
-            )
-            self.out_layers = nn.Sequential(
-                normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),
-                nn.SiLU() if use_scale_shift_norm else nn.Identity(),
-                nn.Dropout(p=dropout),
-                zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
-            )
-            if self.out_channels == in_channels:
-                self.skip_connection = nn.Identity()
-            else:
-                self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
-            self.set_weights_ldm()
-        elif self.overwrite_for_score_vde:
-            in_ch = in_channels
-            out_ch = out_channels
-
-            eps = 1e-6
-            num_groups = min(in_ch // 4, 32)
-            num_groups_out = min(out_ch // 4, 32)
-            temb_dim = temb_channels
-
-            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
-            self.up = up
-            self.down = down
-            self.Conv_0 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
-            if temb_dim is not None:
-                self.Dense_0 = nn.Linear(temb_dim, out_ch)
-                nn.init.zeros_(self.Dense_0.bias)
-
-            self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
-            self.Dropout_0 = nn.Dropout(dropout)
-            self.Conv_1 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
-            if in_ch != out_ch or up or down:
-                # 1x1 convolution with DDPM initialization.
-                self.Conv_2 = nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0)
-
-            self.in_ch = in_ch
-            self.out_ch = out_ch
-            self.set_weights_score_vde()
-
-    def set_weights_grad_tts(self):
-        self.conv1.weight.data = self.block1.block[0].weight.data
-        self.conv1.bias.data = self.block1.block[0].bias.data
-        self.norm1.weight.data = self.block1.block[1].weight.data
-        self.norm1.bias.data = self.block1.block[1].bias.data
-
-        self.conv2.weight.data = self.block2.block[0].weight.data
-        self.conv2.bias.data = self.block2.block[0].bias.data
-        self.norm2.weight.data = self.block2.block[1].weight.data
-        self.norm2.bias.data = self.block2.block[1].bias.data
-
-        self.temb_proj.weight.data = self.mlp[1].weight.data
-        self.temb_proj.bias.data = self.mlp[1].bias.data
-
-        if self.in_channels != self.out_channels:
-            self.nin_shortcut.weight.data = self.res_conv.weight.data
-            self.nin_shortcut.bias.data = self.res_conv.bias.data
-
-    def set_weights_ldm(self):
-        self.norm1.weight.data = self.in_layers[0].weight.data
-        self.norm1.bias.data = self.in_layers[0].bias.data
-
-        self.conv1.weight.data = self.in_layers[-1].weight.data
-        self.conv1.bias.data = self.in_layers[-1].bias.data
-
-        self.temb_proj.weight.data = self.emb_layers[-1].weight.data
-        self.temb_proj.bias.data = self.emb_layers[-1].bias.data
-
-        self.norm2.weight.data = self.out_layers[0].weight.data
-        self.norm2.bias.data = self.out_layers[0].bias.data
-
-        self.conv2.weight.data = self.out_layers[-1].weight.data
-        self.conv2.bias.data = self.out_layers[-1].bias.data
-
-        if self.in_channels != self.out_channels:
-            self.nin_shortcut.weight.data = self.skip_connection.weight.data
-            self.nin_shortcut.bias.data = self.skip_connection.bias.data
-
-    def set_weights_score_vde(self):
-        self.conv1.weight.data = self.Conv_0.weight.data
-        self.conv1.bias.data = self.Conv_0.bias.data
-        self.norm1.weight.data = self.GroupNorm_0.weight.data
-        self.norm1.bias.data = self.GroupNorm_0.bias.data
-
-        self.conv2.weight.data = self.Conv_1.weight.data
-        self.conv2.bias.data = self.Conv_1.bias.data
-        self.norm2.weight.data = self.GroupNorm_1.weight.data
-        self.norm2.bias.data = self.GroupNorm_1.bias.data
-
-        self.temb_proj.weight.data = self.Dense_0.weight.data
-        self.temb_proj.bias.data = self.Dense_0.bias.data
-
-        if self.in_channels != self.out_channels or self.up or self.down:
-            self.nin_shortcut.weight.data = self.Conv_2.weight.data
-            self.nin_shortcut.bias.data = self.Conv_2.bias.data
-
-    def forward(self, x, temb, hey=False, mask=1.0):
-        # TODO(Patrick) eventually this class should be split into multiple classes
-        # too many if else statements
-        if self.overwrite_for_grad_tts and not self.is_overwritten:
-            self.set_weights_grad_tts()
-            self.is_overwritten = True
-        # elif self.overwrite_for_score_vde and not self.is_overwritten:
-        #     self.set_weights_score_vde()
-        #     self.is_overwritten = True
-
-        # h2 tensor(110029.2109)
-        # h3 tensor(49596.9492)
-
-        h = x
-
-        h = h * mask
-        if self.pre_norm:
-            h = self.norm1(h)
-            h = self.nonlinearity(h)
-
-        if self.upsample is not None:
-            x = self.upsample(x)
-            h = self.upsample(h)
-        elif self.downsample is not None:
-            x = self.downsample(x)
-            h = self.downsample(h)
-
-        h = self.conv1(h)
-
-        if not self.pre_norm:
-            h = self.norm1(h)
-            h = self.nonlinearity(h)
-        h = h * mask
-
-        if temb is not None:
-            temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-        else:
-            temb = 0
-
-        if self.time_embedding_norm == "scale_shift":
-            scale, shift = torch.chunk(temb, 2, dim=1)
-
-            h = self.norm2(h)
-            h = h + h * scale + shift
-            h = self.nonlinearity(h)
-        elif self.time_embedding_norm == "default":
-            h = h + temb
-            h = h * mask
-            if self.pre_norm:
-                h = self.norm2(h)
-                h = self.nonlinearity(h)
-
-        h = self.dropout(h)
-        h = self.conv2(h)
-
-        if not self.pre_norm:
-            h = self.norm2(h)
-            h = self.nonlinearity(h)
-        h = h * mask
-
-        x = x * mask
-        if self.nin_shortcut is not None:
-            x = self.nin_shortcut(x)
-
-        out = (x + h) / self.output_scale_factor
-
-        return out
-
-
-# TODO(Patrick) - just there to convert the weights; can delete afterward
-class Block(torch.nn.Module):
-    def __init__(self, dim, dim_out, groups=8):
-        super(Block, self).__init__()
-        self.block = torch.nn.Sequential(
-            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
-        )
-
-
-# HELPER Modules
-
-
-def normalization(channels, swish=0.0):
-    """
-    Make a standard normalization layer, with an optional swish activation.
-
-    :param channels: number of input channels. :return: an nn.Module for normalization.
-    """
-    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
-
-
-class GroupNorm32(nn.GroupNorm):
-    def __init__(self, num_groups, num_channels, swish, eps=1e-5):
-        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
-        self.swish = swish
-
-    def forward(self, x):
-        y = super().forward(x.float()).to(x.dtype)
-        if self.swish == 1.0:
-            y = F.silu(y)
-        elif self.swish:
-            y = y * F.sigmoid(y * float(self.swish))
-        return y
-
-
-def linear(*args, **kwargs):
-    """
-    Create a linear module.
-    """
-    return nn.Linear(*args, **kwargs)
-
-
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
 
 class Mish(torch.nn.Module):
     def forward(self, x):
         return x * torch.tanh(torch.nn.functional.softplus(x))
 
 
-class Conv1dBlock(nn.Module):
-    """
-    Conv1d --> GroupNorm --> Mish
-    """
-
-    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
-        super().__init__()
-
-        self.block = nn.Sequential(
-            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
-            RearrangeDim(),
-            # Rearrange("batch channels horizon -> batch channels 1 horizon"),
-            nn.GroupNorm(n_groups, out_channels),
-            RearrangeDim(),
-            # Rearrange("batch channels 1 horizon -> batch channels horizon"),
-            nn.Mish(),
-        )
-
-    def forward(self, x):
-        return self.block(x)
-
-
-class RearrangeDim(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, tensor):
-        if len(tensor.shape) == 2:
-            return tensor[:, :, None]
-        if len(tensor.shape) == 3:
-            return tensor[:, :, None, :]
-        elif len(tensor.shape) == 4:
-            return tensor[:, :, 0, :]
-        else:
-            raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
-
-
 def upsample_2d(x, k=None, factor=2, gain=1):
     r"""Upsample2D a batch of 2D images with the given filter.