diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 93c0cf17..ae6754b1 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -614,19 +614,11 @@ class ResnetBlockBigGANpp(nn.Module): h = self.act(self.GroupNorm_0(x)) if self.up: - if self.fir: - h = upsample_2d(h, self.fir_kernel, factor=2) - x = upsample_2d(x, self.fir_kernel, factor=2) - else: - h = naive_upsample_2d(h, factor=2) - x = naive_upsample_2d(x, factor=2) + h = upsample_2d(h, self.fir_kernel, factor=2) + x = upsample_2d(x, self.fir_kernel, factor=2) elif self.down: - if self.fir: - h = downsample_2d(h, self.fir_kernel, factor=2) - x = downsample_2d(x, self.fir_kernel, factor=2) - else: - h = naive_downsample_2d(h, factor=2) - x = naive_downsample_2d(x, factor=2) + h = downsample_2d(h, self.fir_kernel, factor=2) + x = downsample_2d(x, self.fir_kernel, factor=2) h = self.Conv_0(h) # Add bias to each feature map conditioned on the time embedding diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 8acf3372..508bac14 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -417,20 +417,16 @@ class Upsample(nn.Module): def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): super().__init__() out_ch = out_ch if out_ch else in_ch - if not fir: - if with_conv: - self.Conv_0 = conv3x3(in_ch, out_ch) - else: - if with_conv: - self.Conv2d_0 = Conv2d( - in_ch, - out_ch, - kernel=3, - up=True, - resample_kernel=fir_kernel, - use_bias=True, - kernel_init=default_init(), - ) + if with_conv: + self.Conv2d_0 = Conv2d( + in_ch, + out_ch, + kernel=3, + up=True, + resample_kernel=fir_kernel, + use_bias=True, + kernel_init=default_init(), + ) self.fir = fir self.with_conv = with_conv self.fir_kernel = fir_kernel @@ -438,15 +434,10 @@ class Upsample(nn.Module): def forward(self, x): B, C, H, W = x.shape - if not self.fir: - h = F.interpolate(x, (H * 2, W * 2), "nearest") - if self.with_conv: - h = self.Conv_0(h) + if not self.with_conv: + h = upsample_2d(x, self.fir_kernel, factor=2) else: - if not self.with_conv: - h = upsample_2d(x, self.fir_kernel, factor=2) - else: - h = self.Conv2d_0(x) + h = self.Conv2d_0(x) return h @@ -455,20 +446,16 @@ class Downsample(nn.Module): def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): super().__init__() out_ch = out_ch if out_ch else in_ch - if not fir: - if with_conv: - self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) - else: - if with_conv: - self.Conv2d_0 = Conv2d( - in_ch, - out_ch, - kernel=3, - down=True, - resample_kernel=fir_kernel, - use_bias=True, - kernel_init=default_init(), - ) + if with_conv: + self.Conv2d_0 = Conv2d( + in_ch, + out_ch, + kernel=3, + down=True, + resample_kernel=fir_kernel, + use_bias=True, + kernel_init=default_init(), + ) self.fir = fir self.fir_kernel = fir_kernel self.with_conv = with_conv @@ -476,17 +463,10 @@ class Downsample(nn.Module): def forward(self, x): B, C, H, W = x.shape - if not self.fir: - if self.with_conv: - x = F.pad(x, (0, 1, 0, 1)) - x = self.Conv_0(x) - else: - x = F.avg_pool2d(x, 2, stride=2) + if not self.with_conv: + x = downsample_2d(x, self.fir_kernel, factor=2) else: - if not self.with_conv: - x = downsample_2d(x, self.fir_kernel, factor=2) - else: - x = self.Conv2d_0(x) + x = self.Conv2d_0(x) return x