add tests for upsample blocks
This commit is contained in:
parent
e13ee8b5b3
commit
dc7c49e4e4
|
@ -1,4 +1,3 @@
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
|
|||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def conv_transpose_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
|
@ -73,7 +73,7 @@ class Upsample(nn.Module):
|
|||
self.use_conv_transpose = use_conv_transpose
|
||||
|
||||
if use_conv_transpose:
|
||||
self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1)
|
||||
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
|
@ -138,6 +138,7 @@ class UNetUpsample(nn.Module):
|
|||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class GlideUpsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
|
@ -199,13 +200,14 @@ class LDMUpsample(nn.Module):
|
|||
|
||||
class GradTTSUpsample(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super(Upsample, self).__init__()
|
||||
super(GradTTSUpsample, self).__init__()
|
||||
self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
# TODO (patil-suraj): needs test
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
|
|
|
@ -22,6 +22,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from diffusers.models.embeddings import get_timestep_embedding
|
||||
from diffusers.models.resnet import Upsample
|
||||
from diffusers.testing_utils import floats_tensor, slow, torch_device
|
||||
|
||||
|
||||
|
@ -113,3 +114,53 @@ class EmbeddingsTests(unittest.TestCase):
|
|||
torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
|
||||
1e-3,
|
||||
)
|
||||
|
||||
|
||||
class UpsampleBlockTests(unittest.TestCase):
|
||||
def test_upsample_default(self):
|
||||
torch.manual_seed(0)
|
||||
sample = torch.randn(1, 32, 32, 32)
|
||||
upsample = Upsample(channels=32, use_conv=False)
|
||||
with torch.no_grad():
|
||||
upsampled = upsample(sample)
|
||||
|
||||
assert upsampled.shape == (1, 32, 64, 64)
|
||||
output_slice = upsampled[0, -1, -3:, -3:]
|
||||
expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_upsample_with_conv(self):
|
||||
torch.manual_seed(0)
|
||||
sample = torch.randn(1, 32, 32, 32)
|
||||
upsample = Upsample(channels=32, use_conv=True)
|
||||
with torch.no_grad():
|
||||
upsampled = upsample(sample)
|
||||
|
||||
assert upsampled.shape == (1, 32, 64, 64)
|
||||
output_slice = upsampled[0, -1, -3:, -3:]
|
||||
expected_slice = torch.tensor([0.7145, 1.3773, 0.3492, 0.8448, 1.0839, -0.3341, 0.5956, 0.1250, -0.4841])
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_upsample_with_conv_out_dim(self):
|
||||
torch.manual_seed(0)
|
||||
sample = torch.randn(1, 32, 32, 32)
|
||||
upsample = Upsample(channels=32, use_conv=True, out_channels=64)
|
||||
with torch.no_grad():
|
||||
upsampled = upsample(sample)
|
||||
|
||||
assert upsampled.shape == (1, 64, 64, 64)
|
||||
output_slice = upsampled[0, -1, -3:, -3:]
|
||||
expected_slice = torch.tensor([0.2703, 0.1656, -0.2538, -0.0553, -0.2984, 0.1044, 0.1155, 0.2579, 0.7755])
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_upsample_with_transpose(self):
|
||||
torch.manual_seed(0)
|
||||
sample = torch.randn(1, 32, 32, 32)
|
||||
upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
|
||||
with torch.no_grad():
|
||||
upsampled = upsample(sample)
|
||||
|
||||
assert upsampled.shape == (1, 32, 64, 64)
|
||||
output_slice = upsampled[0, -1, -3:, -3:]
|
||||
expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
|
Loading…
Reference in New Issue