From fa7443c899648525f3eb477faea84cac52baf5f9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 1 Jul 2022 15:39:57 +0000 Subject: [PATCH] finish resnet --- src/diffusers/models/resnet.py | 23 +++-- .../models/unet_sde_score_estimation.py | 94 ++++++++++++++++--- 2 files changed, 92 insertions(+), 25 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index b0adaedb..bdee6946 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -380,7 +380,7 @@ class ResnetBlock(nn.Module): eps=1e-6, non_linearity="swish", time_embedding_norm="default", - fir_kernel=(1, 3, 3, 1), + kernel=None, output_scale_factor=1.0, use_nin_shortcut=None, up=False, @@ -433,8 +433,18 @@ class ResnetBlock(nn.Module): # elif down: # self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") # self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") - self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None - self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None + + self.upsample = self.downsample = None + if self.up and kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, k=fir_kernel) + elif self.up and kernel is None: + self.upsample = Upsample(in_channels, use_conv=False, dims=2) + elif self.down and kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, k=fir_kernel) + elif self.down and kernel is None: + self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut @@ -505,8 +515,6 @@ class ResnetBlock(nn.Module): self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps) self.up = up self.down = down - self.fir_kernel = fir_kernel - self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1) if temb_dim is not None: self.Dense_0 = nn.Linear(temb_dim, out_ch) @@ -525,11 +533,6 @@ class ResnetBlock(nn.Module): self.out_ch = out_ch # TODO(Patrick) - move to main init - if self.up: - self.upsample = functools.partial(upsample_2d, k=self.fir_kernel) - if self.down: - self.downsample = functools.partial(downsample_2d, k=self.fir_kernel) - self.is_overwritten = False def set_weights_grad_tts(self): diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index e9525d7a..727bec30 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -348,16 +348,18 @@ class NCSNpp(ModelMixin, ConfigMixin): for i_block in range(num_res_blocks): out_ch = nf * ch_mult[i_level] # modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) - modules.append(ResnetNew( - in_channels=in_ch, - out_channels=out_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - )) + modules.append( + ResnetNew( + in_channels=in_ch, + out_channels=out_ch, + temb_channels=4 * nf, + output_scale_factor=np.sqrt(2.0), + non_linearity="silu", + groups=min(in_ch // 4, 32), + groups_out=min(out_ch // 4, 32), + overwrite_for_score_vde=True, + ) + ) in_ch = out_ch if all_resolutions[i_level] in attn_resolutions: @@ -365,7 +367,21 @@ class NCSNpp(ModelMixin, ConfigMixin): hs_c.append(in_ch) if i_level != self.num_resolutions - 1: - modules.append(ResnetBlock(down=True, in_ch=in_ch)) +# modules.append(ResnetBlock(down=True, in_ch=in_ch)) + modules.append( + ResnetNew( + in_channels=in_ch, + temb_channels=4 * nf, + output_scale_factor=np.sqrt(2.0), + non_linearity="silu", + groups=min(in_ch // 4, 32), + groups_out=min(out_ch // 4, 32), + overwrite_for_score_vde=True, + down=True, + kernel="fir", # TODO(Patrick) - it seems like both fir and non-fir kernels are fine + use_nin_shortcut=True, + ) + ) if progressive_input == "input_skip": modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) @@ -379,16 +395,50 @@ class NCSNpp(ModelMixin, ConfigMixin): hs_c.append(in_ch) in_ch = hs_c[-1] - modules.append(ResnetBlock(in_ch=in_ch)) +# modules.append(ResnetBlock(in_ch=in_ch)) + modules.append( + ResnetNew( + in_channels=in_ch, + temb_channels=4 * nf, + output_scale_factor=np.sqrt(2.0), + non_linearity="silu", + groups=min(in_ch // 4, 32), + groups_out=min(out_ch // 4, 32), + overwrite_for_score_vde=True, + ) + ) modules.append(AttnBlock(channels=in_ch)) - modules.append(ResnetBlock(in_ch=in_ch)) +# modules.append(ResnetBlock(in_ch=in_ch)) + modules.append( + ResnetNew( + in_channels=in_ch, + temb_channels=4 * nf, + output_scale_factor=np.sqrt(2.0), + non_linearity="silu", + groups=min(in_ch // 4, 32), + groups_out=min(out_ch // 4, 32), + overwrite_for_score_vde=True, + ) + ) pyramid_ch = 0 # Upsampling block for i_level in reversed(range(self.num_resolutions)): for i_block in range(num_res_blocks + 1): out_ch = nf * ch_mult[i_level] - modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) +# modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) + modules.append( + ResnetNew( + in_channels=in_ch + hs_c.pop(), + out_channels=out_ch, + temb_channels=4 * nf, + output_scale_factor=np.sqrt(2.0), + non_linearity="silu", + groups=min(in_ch // 4, 32), + groups_out=min(out_ch // 4, 32), + overwrite_for_score_vde=True, + ) + ) in_ch = out_ch if all_resolutions[i_level] in attn_resolutions: @@ -420,7 +470,21 @@ class NCSNpp(ModelMixin, ConfigMixin): raise ValueError(f"{progressive} is not a valid name") if i_level != 0: - modules.append(ResnetBlock(in_ch=in_ch, up=True)) +# modules.append(ResnetBlock(in_ch=in_ch, up=True)) + modules.append( + ResnetNew( + in_channels=in_ch, + temb_channels=4 * nf, + output_scale_factor=np.sqrt(2.0), + non_linearity="silu", + groups=min(in_ch // 4, 32), + groups_out=min(out_ch // 4, 32), + overwrite_for_score_vde=True, + up=True, + kernel="fir", + use_nin_shortcut=True, + ) + ) assert not hs_c