From b17d49f8631e305d56ec991d2bffb306c690fa7f Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 19 Sep 2022 15:52:52 +0200 Subject: [PATCH] Fix `_upsample_2d` (#535) * Fix _upsample_2d Co-authored-by: ydshieh --- src/diffusers/models/resnet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 55ae42c2..97f3c02a 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -149,7 +149,6 @@ class FirUpsample2D(nn.Module): stride = (factor, factor) # Determine data dimensions. - stride = [1, 1, factor, factor] output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW) output_padding = ( output_shape[0] - (x.shape[2] - 1) * stride[0] - convH, @@ -161,7 +160,7 @@ class FirUpsample2D(nn.Module): # Transpose weights. weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) - weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) + weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)