From dc7c49e4e419ef0888647873b0fb2e233fea6dc2 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Mon, 27 Jun 2022 15:50:54 +0200
Subject: [PATCH] add tests for upsample blocks

---
 src/diffusers/models/resnet.py | 14 ++++++----
 tests/test_layers_utils.py     | 51 ++++++++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 04e3735d..2abb5ce6 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -1,4 +1,3 @@
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
         return nn.Conv3d(*args, **kwargs)
     raise ValueError(f"unsupported dimensions: {dims}")
 
+
 def conv_transpose_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
@@ -73,7 +73,7 @@ class Upsample(nn.Module):
         self.use_conv_transpose = use_conv_transpose
 
         if use_conv_transpose:
-            self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1)
+            self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
         elif use_conv:
             self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
 
@@ -81,15 +81,15 @@ class Upsample(nn.Module):
         assert x.shape[1] == self.channels
         if self.use_conv_transpose:
             return self.conv(x)
-
+
         if self.dims == 3:
             x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
         else:
             x = F.interpolate(x, scale_factor=2.0, mode="nearest")
-
+
         if self.use_conv:
             x = self.conv(x)
-
+
         return x
 
 
@@ -138,6 +138,7 @@ class UNetUpsample(nn.Module):
             x = self.conv(x)
         return x
 
+
 class GlideUpsample(nn.Module):
     """
     An upsampling layer with an optional convolution.
@@ -199,13 +200,14 @@ class LDMUpsample(nn.Module):
 
 class GradTTSUpsample(torch.nn.Module):
     def __init__(self, dim):
-        super(Upsample, self).__init__()
+        super(GradTTSUpsample, self).__init__()
         self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
 
     def forward(self, x):
         return self.conv(x)
 
 
+# TODO (patil-suraj): needs test
 class Upsample1d(nn.Module):
     def __init__(self, dim):
         super().__init__()
diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py
index 42a42610..cde7fc6b 100755
--- a/tests/test_layers_utils.py
+++ b/tests/test_layers_utils.py
@@ -22,6 +22,7 @@ import numpy as np
 import torch
 
 from diffusers.models.embeddings import get_timestep_embedding
+from diffusers.models.resnet import Upsample
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 
 
@@ -113,3 +114,53 @@ class EmbeddingsTests(unittest.TestCase):
             torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
             1e-3,
         )
+
+
+class UpsampleBlockTests(unittest.TestCase):
+    def test_upsample_default(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=False)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_upsample_with_conv(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=True)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([0.7145, 1.3773, 0.3492, 0.8448, 1.0839, -0.3341, 0.5956, 0.1250, -0.4841])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_upsample_with_conv_out_dim(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=True, out_channels=64)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 64, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([0.2703, 0.1656, -0.2538, -0.0553, -0.2984, 0.1044, 0.1155, 0.2579, 0.7755])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_upsample_with_transpose(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 32, 32)
+        upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
+        with torch.no_grad():
+            upsampled = upsample(sample)
+
+        assert upsampled.shape == (1, 32, 64, 64)
+        output_slice = upsampled[0, -1, -3:, -3:]
+        expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
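
Usage note (not part of the patch): the snippet below is a minimal sketch of the behaviour the new tests pin down, assuming the diffusers source tree from this commit is importable; the expected output shapes come directly from the test assertions above, and the constructor arguments are the same ones the tests use.

import torch

from diffusers.models.resnet import Upsample

x = torch.randn(1, 32, 32, 32)  # (batch, channels, height, width)

# Nearest-neighbour interpolation only: spatial dims double, channel count unchanged.
print(Upsample(channels=32, use_conv=False)(x).shape)  # torch.Size([1, 32, 64, 64])

# Interpolation followed by a 3x3 convolution that can change the channel count.
print(Upsample(channels=32, use_conv=True, out_channels=64)(x).shape)  # torch.Size([1, 64, 64, 64])

# Learned upsampling via a transposed convolution; this path exercises the
# self.out_channels fix in the resnet.py hunk above.
print(Upsample(channels=32, use_conv=False, use_conv_transpose=True)(x).shape)  # torch.Size([1, 32, 64, 64])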