diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 46824a85..b0adaedb 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -436,10 +436,9 @@ class ResnetBlock(nn.Module): self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None - self.nin_shortcut = use_nin_shortcut - if self.use_nin_shortcut is None: - self.use_nin_shortcut = self.in_channels != self.out_channels + self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut + self.nin_shortcut = None if self.use_nin_shortcut: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) @@ -526,8 +525,10 @@ class ResnetBlock(nn.Module): self.out_ch = out_ch # TODO(Patrick) - move to main init - self.upsample = functools.partial(upsample_2d, k=self.fir_kernel) - self.downsample = functools.partial(downsample_2d, k=self.fir_kernel) + if self.up: + self.upsample = functools.partial(upsample_2d, k=self.fir_kernel) + if self.down: + self.downsample = functools.partial(downsample_2d, k=self.fir_kernel) self.is_overwritten = False @@ -647,7 +648,7 @@ class ResnetBlock(nn.Module): if self.nin_shortcut is not None: x = self.nin_shortcut(x) - return x + h + return (x + h) / self.output_scale_factor # TODO(Patrick) - just there to convert the weights; can delete afterward diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index c900ae78..e9525d7a 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -348,18 +348,16 @@ class NCSNpp(ModelMixin, ConfigMixin): for i_block in range(num_res_blocks): out_ch = nf * ch_mult[i_level] # modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) - modules.append( - ResnetNew( - in_channels=in_ch, - out_channels=out_ch, - temb_channels=4 * nf, - output_scale_factor=np.sqrt(2.0), - non_linearity="silu", - groups=min(in_ch // 4, 32), - groups_out=min(out_ch // 4, 32), - overwrite_for_score_vde=True, - ) - ) + modules.append(ResnetNew( + in_channels=in_ch, + out_channels=out_ch, + temb_channels=4 * nf, + output_scale_factor=np.sqrt(2.0), + non_linearity="silu", + groups=min(in_ch // 4, 32), + groups_out=min(out_ch // 4, 32), + overwrite_for_score_vde=True, + )) in_ch = out_ch if all_resolutions[i_level] in attn_resolutions: