From dcb9070bc28401403d55b8c2097ea54a67e52ddd Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 1 Jul 2022 17:56:59 +0000
Subject: [PATCH] quick fix to include non-fir kernels for sde-vp

---
 src/diffusers/models/resnet.py                | 114 +++---------------
 .../models/unet_sde_score_estimation.py       |   4 +-
 2 files changed, 18 insertions(+), 100 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 82079d4b..4481e533 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -237,24 +237,23 @@ class ResnetBlock(nn.Module):
         elif non_linearity == "silu":
             self.nonlinearity = nn.SiLU()
 
-        # if up:
-        #     self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
-        #     self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
-        # elif down:
-        #     self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
-        #     self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
-
         self.upsample = self.downsample = None
-        if self.up and kernel == "fir":
-            fir_kernel = (1, 3, 3, 1)
-            self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
-        elif self.up and kernel is None:
-            self.upsample = Upsample(in_channels, use_conv=False, dims=2)
-        elif self.down and kernel == "fir":
-            fir_kernel = (1, 3, 3, 1)
-            self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
-        elif self.down and kernel is None:
-            self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+        if self.up:
+            if kernel == "fir":
+                fir_kernel = (1, 3, 3, 1)
+                self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
+            elif kernel == "sde_vp":
+                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+            else:
+                self.upsample = Upsample(in_channels, use_conv=False, dims=2)
+        elif self.down:
+            if kernel == "fir":
+                fir_kernel = (1, 3, 3, 1)
+                self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
+            elif kernel == "sde_vp":
+                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+            else:
+                self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
 
         self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
 
@@ -473,87 +472,6 @@ class Block(torch.nn.Module):
         )
 
 
-# unet_score_estimation.py
-class ResnetBlockBigGANpp(nn.Module):
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        up=False,
-        down=False,
-        dropout=0.1,
-        fir=False,
-        fir_kernel=(1, 3, 3, 1),
-        skip_rescale=True,
-        init_scale=0.0,
-    ):
-        super().__init__()
-
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.up = up
-        self.down = down
-        self.fir = fir
-        self.fir_kernel = fir_kernel
-
-        if self.up:
-            if self.fir:
-                self.upsample = partial(upsample_2d, k=self.fir_kernel, factor=2)
-            else:
-                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
-        elif self.down:
-            if self.fir:
-                self.downsample = partial(downsample_2d, k=self.fir_kernel, factor=2)
-            else:
-                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
-
-        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
-        if in_ch != out_ch or up or down:
-            # 1x1 convolution with DDPM initialization.
-            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
-
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.in_ch = in_ch
-        self.out_ch = out_ch
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-
-        if self.up:
-            h = self.upsample(h)
-            x = self.upsample(x)
-        elif self.down:
-            h = self.downsample(h)
-            x = self.downsample(x)
-
-        h = self.Conv_0(h)
-        # Add bias to each feature map conditioned on the time embedding
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-
-        if self.in_ch != self.out_ch or self.up or self.down:
-            x = self.Conv_2(x)
-
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
-
-
 # unet_rl.py
 class ResidualTemporalBlock(nn.Module):
     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py
index f4c7f66c..bd01432f 100644
--- a/src/diffusers/models/unet_sde_score_estimation.py
+++ b/src/diffusers/models/unet_sde_score_estimation.py
@@ -373,7 +373,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                         groups_out=min(out_ch // 4, 32),
                         overwrite_for_score_vde=True,
                         down=True,
-                        kernel="fir",  # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
+                        kernel="fir" if self.fir else "sde_vp",
                         use_nin_shortcut=True,
                     )
                 )
@@ -473,7 +473,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
                         groups_out=min(out_ch // 4, 32),
                         overwrite_for_score_vde=True,
                         up=True,
-                        kernel="fir",  # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
+                        kernel="fir" if self.fir else "sde_vp",
                         use_nin_shortcut=True,
                     )
                 )
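
A small runnable sketch (illustrative only, not part of the patch) of what the
two new "sde_vp" branches compute: partial(F.interpolate, scale_factor=2.0,
mode="nearest") doubles the spatial resolution by copying each pixel into a
2x2 block, and partial(F.avg_pool2d, kernel_size=2, stride=2) halves it by
averaging each non-overlapping 2x2 window. The local names `upsample` and
`downsample` below are just for the demo.

    # Illustration of the two resampling callables the patch assigns for
    # kernel == "sde_vp", exercised on a dummy NCHW tensor.
    from functools import partial

    import torch
    import torch.nn.functional as F

    upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
    downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)

    x = torch.randn(1, 3, 8, 8)
    print(upsample(x).shape)    # torch.Size([1, 3, 16, 16])
    print(downsample(x).shape)  # torch.Size([1, 3, 4, 4])

    # Nearest upsampling copies each pixel into a 2x2 block, so average
    # pooling immediately afterwards recovers the input exactly.
    assert torch.allclose(downsample(upsample(x)), x)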
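For comparison, the "fir" branch keeps FIR resampling via upsample_2d /
downsample_2d. A sketch of what the (1, 3, 3, 1) tap vector turns into,
assuming the usual StyleGAN2-style setup those helpers implement (normalize
the 1-D taps, then build a separable 2-D kernel via an outer product); exact
gain handling inside the helpers may differ:

    import torch

    k = torch.tensor([1.0, 3.0, 3.0, 1.0])
    k = k / k.sum()          # normalize the 1-D taps
    k2d = torch.outer(k, k)  # separable 4x4 smoothing kernel
    print(k2d.sum())         # sums to 1 -> unit DC gain, brightness-preserving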
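The unet_sde_score_estimation.py change makes the block follow the model's
`fir` flag instead of hard-coding kernel="fir", so FIR-trained models keep the
smoothing resampler while the non-FIR ("sde_vp", presumably the VP-SDE
variants, going by the name) get the plain interpolate/avg-pool path above.
The difference matters because FIR smooths before subsampling, whereas
avg_pool2d is a plain box filter. A rough, self-contained approximation of the
FIR downsample path; `fir_downsample_2d_sketch` is a hypothetical stand-in,
and diffusers' downsample_2d handles padding and gain slightly differently:

    import torch
    import torch.nn.functional as F

    def fir_downsample_2d_sketch(x, taps=(1.0, 3.0, 3.0, 1.0)):
        # Hypothetical sketch, not the library API: depthwise smooth with the
        # normalized separable FIR kernel, then subsample by a factor of 2.
        k = torch.tensor(taps)
        k2d = torch.outer(k, k)
        k2d = k2d / k2d.sum()                      # unit DC gain
        channels = x.shape[1]
        w = k2d.repeat(channels, 1, 1, 1)          # one 4x4 filter per channel
        x = F.pad(x, (1, 1, 1, 1))                 # pad so output is H/2 x W/2
        return F.conv2d(x, w, stride=2, groups=channels)

    x = torch.randn(1, 3, 8, 8)
    print(fir_downsample_2d_sketch(x).shape)       # torch.Size([1, 3, 4, 4])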