diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index bdee6946..42a92623 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -1,6 +1,5 @@
 from abc import abstractmethod
 
-import functools
 import numpy as np
 import torch
 import torch.nn as nn
@@ -160,211 +159,7 @@ class Downsample(nn.Module):
 # return self.conv(x)
 
 
-# RESNETS
-# unet_score_estimation.py
-class ResnetBlockBigGANppNew(nn.Module):
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        up=False,
-        down=False,
-        dropout=0.1,
-        fir_kernel=(1, 3, 3, 1),
-        skip_rescale=True,
-        init_scale=0.0,
-        overwrite=True,
-    ):
-        super().__init__()
-
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.up = up
-        self.down = down
-        self.fir_kernel = fir_kernel
-
-        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
-        if in_ch != out_ch or up or down:
-            # 1x1 convolution with DDPM initialization.
-            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
-
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.in_ch = in_ch
-        self.out_ch = out_ch
-
-        self.is_overwritten = False
-        self.overwrite = overwrite
-        if overwrite:
-            self.output_scale_factor = np.sqrt(2.0)
-            self.in_channels = in_channels = in_ch
-            self.out_channels = out_channels = out_ch
-            groups = min(in_ch // 4, 32)
-            out_groups = min(out_ch // 4, 32)
-            eps = 1e-6
-            self.pre_norm = True
-            temb_channels = temb_dim
-            non_linearity = "silu"
-            self.time_embedding_norm = time_embedding_norm = "default"
-
-            if self.pre_norm:
-                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
-            else:
-                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
-
-            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-
-            if time_embedding_norm == "default":
-                self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-            elif time_embedding_norm == "scale_shift":
-                self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
-
-            self.norm2 = Normalize(out_channels, num_groups=out_groups, eps=eps)
-            self.dropout = torch.nn.Dropout(dropout)
-            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
-
-            if non_linearity == "swish":
-                self.nonlinearity = nonlinearity
-            elif non_linearity == "mish":
-                self.nonlinearity = Mish()
-            elif non_linearity == "silu":
-                self.nonlinearity = nn.SiLU()
-
-            if up:
-                self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
-                self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
-            elif down:
-                self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
-                self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
-
-            if self.in_channels != self.out_channels or self.up or self.down:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
-    def set_weights(self):
-        self.conv1.weight.data = self.Conv_0.weight.data
-        self.conv1.bias.data = self.Conv_0.bias.data
-        self.norm1.weight.data = self.GroupNorm_0.weight.data
-        self.norm1.bias.data = self.GroupNorm_0.bias.data
-
-        self.conv2.weight.data = self.Conv_1.weight.data
-        self.conv2.bias.data = self.Conv_1.bias.data
-        self.norm2.weight.data = self.GroupNorm_1.weight.data
-        self.norm2.bias.data = self.GroupNorm_1.bias.data
-
-        self.temb_proj.weight.data = self.Dense_0.weight.data
-        self.temb_proj.bias.data = self.Dense_0.bias.data
-
-        if self.in_channels != self.out_channels or self.up or self.down:
-            self.nin_shortcut.weight.data = self.Conv_2.weight.data
-            self.nin_shortcut.bias.data = self.Conv_2.bias.data
-
-    def forward(self, x, temb=None):
-        if self.overwrite and not self.is_overwritten:
-            self.set_weights()
-            self.is_overwritten = True
-
-        orig_x = x
-        h = self.act(self.GroupNorm_0(x))
-
-        if self.up:
-            h = upsample_2d(h, self.fir_kernel, factor=2)
-            x = upsample_2d(x, self.fir_kernel, factor=2)
-        elif self.down:
-            h = downsample_2d(h, self.fir_kernel, factor=2)
-            x = downsample_2d(x, self.fir_kernel, factor=2)
-
-        h = self.Conv_0(h)
-        # Add bias to each feature map conditioned on the time embedding
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-
-        if self.in_ch != self.out_ch or self.up or self.down:
-            x = self.Conv_2(x)
-
-        if not self.skip_rescale:
-            raise ValueError("Is this branch run?!")
-# import ipdb; ipdb.set_trace()
-            result = x + h
-        else:
-            result = (x + h) / np.sqrt(2.0)
-
-        result_2 = self.forward_2(orig_x, temb)
-
-        return result_2
-
-    def forward_2(self, x, temb, mask=1.0):
-        h = x
-        h = h * mask
-        if self.pre_norm:
-            h = self.norm1(h)
-            h = self.nonlinearity(h)
-
-# if self.up or self.down:
-# x = self.x_upd(x)
-# h = self.h_upd(h)
-        if self.up:
-            h = upsample_2d(h, self.fir_kernel, factor=2)
-            x = upsample_2d(x, self.fir_kernel, factor=2)
-        elif self.down:
-            h = downsample_2d(h, self.fir_kernel, factor=2)
-            x = downsample_2d(x, self.fir_kernel, factor=2)
-
-        h = self.conv1(h)
-
-        if not self.pre_norm:
-            h = self.norm1(h)
-            h = self.nonlinearity(h)
-        h = h * mask
-
-        temb = self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-
-        if self.time_embedding_norm == "scale_shift":
-            scale, shift = torch.chunk(temb, 2, dim=1)
-
-            h = self.norm2(h)
-            h = h + h * scale + shift
-            h = self.nonlinearity(h)
-        elif self.time_embedding_norm == "default":
-            h = h + temb
-            h = h * mask
-            if self.pre_norm:
-                h = self.norm2(h)
-                h = self.nonlinearity(h)
-        else:
-            raise ValueError("Nananan nanana - don't go here!")
-
-        h = self.dropout(h)
-        h = self.conv2(h)
-
-        if not self.pre_norm:
-            h = self.norm2(h)
-            h = self.nonlinearity(h)
-        h = h * mask
-
-        x = x * mask
-# if self.in_channels != self.out_channels:
-        if self.in_channels != self.out_channels or self.up or self.down:
-            x = self.nin_shortcut(x)
-
-        result = x + h
-
-        return result / self.output_scale_factor
-
-
-# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
+# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
 class ResnetBlock(nn.Module):
     def __init__(
         self,
diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py
index c5a15794..c915dede 100644
--- a/src/diffusers/models/unet_sde_score_estimation.py
+++ b/src/diffusers/models/unet_sde_score_estimation.py
@@ -28,7 +28,7 @@ from ..modeling_utils import ModelMixin
 from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
 from .resnet import downsample_2d, upfirdn2d, upsample_2d
-from .resnet import ResnetBlock
+from .resnet import ResnetBlock
 
 
 def _setup_kernel(k):
@@ -464,7 +464,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                         groups_out=min(out_ch // 4, 32),
                         overwrite_for_score_vde=True,
                         up=True,
-                        kernel="fir",
+                        kernel="fir",  # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
                         use_nin_shortcut=True,
                     )
                 )
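
For context on the `kernel="fir"` TODO above: the `fir_kernel=(1, 3, 3, 1)` taps removed with `ResnetBlockBigGANppNew` are the separable smoothing filter used by `upfirdn2d`-based up/downsampling. Below is a minimal sketch of how such a 1D tap vector is commonly expanded into the normalized 2D kernel; this assumes `_setup_kernel` follows the usual outer-product-and-normalize recipe, whose body is not shown in this diff.

```python
# Sketch (assumption): expand 1D FIR taps such as (1, 3, 3, 1) into the
# normalized 2D filter consumed by upfirdn2d-style resampling.
import numpy as np


def setup_fir_kernel(k):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)  # separable taps -> square 2D kernel
    k /= k.sum()  # normalize so resampling preserves overall magnitude
    assert k.ndim == 2 and k.shape[0] == k.shape[1]
    return k


print(setup_fir_kernel((1, 3, 3, 1)))  # 4x4 smoothing kernel summing to 1
```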