add tests for upsample blocks

This commit is contained in:
patil-suraj 2022-06-27 15:50:54 +02:00
parent e13ee8b5b3
commit dc7c49e4e4
2 changed files with 59 additions and 6 deletions

View File

@ -1,4 +1,3 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def conv_transpose_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
@ -73,7 +73,7 @@ class Upsample(nn.Module):
self.use_conv_transpose = use_conv_transpose
if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1)
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
@ -138,6 +138,7 @@ class UNetUpsample(nn.Module):
x = self.conv(x)
return x
class GlideUpsample(nn.Module):
"""
An upsampling layer with an optional convolution.
@ -199,13 +200,14 @@ class LDMUpsample(nn.Module):
class GradTTSUpsample(torch.nn.Module):
def __init__(self, dim):
super(Upsample, self).__init__()
super(GradTTSUpsample, self).__init__()
self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
# TODO (patil-suraj): needs test
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()

View File

@ -22,6 +22,7 @@ import numpy as np
import torch
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Upsample
from diffusers.testing_utils import floats_tensor, slow, torch_device
@ -113,3 +114,53 @@ class EmbeddingsTests(unittest.TestCase):
torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
1e-3,
)
class UpsampleBlockTests(unittest.TestCase):
def test_upsample_default(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=False)
with torch.no_grad():
upsampled = upsample(sample)
assert upsampled.shape == (1, 32, 64, 64)
output_slice = upsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_upsample_with_conv(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=True)
with torch.no_grad():
upsampled = upsample(sample)
assert upsampled.shape == (1, 32, 64, 64)
output_slice = upsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([0.7145, 1.3773, 0.3492, 0.8448, 1.0839, -0.3341, 0.5956, 0.1250, -0.4841])
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_upsample_with_conv_out_dim(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=True, out_channels=64)
with torch.no_grad():
upsampled = upsample(sample)
assert upsampled.shape == (1, 64, 64, 64)
output_slice = upsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([0.2703, 0.1656, -0.2538, -0.0553, -0.2984, 0.1044, 0.1155, 0.2579, 0.7755])
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_upsample_with_transpose(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
with torch.no_grad():
upsampled = upsample(sample)
assert upsampled.shape == (1, 32, 64, 64)
output_slice = upsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)