Test ResnetBlock2D (#1850)

* test resnet block

* fix code format required by isort

* add torch device

* nit
This commit is contained in:
Erin 2023-01-04 13:57:32 -08:00 committed by GitHub
parent cb8a3dbe34
commit 9e17983d9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 93 additions and 1 deletions

94
tests/test_layers_utils.py Executable file → Normal file
View File

@ -22,7 +22,7 @@ from torch import nn
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock
from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, Upsample2D from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel from diffusers.models.transformer_2d import Transformer2DModel
from diffusers.utils import torch_device from diffusers.utils import torch_device
@ -222,6 +222,98 @@ class Downsample2DBlockTests(unittest.TestCase):
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class ResnetBlock2DTests(unittest.TestCase):
def test_resnet_default(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
temb = torch.randn(1, 128).to(torch_device)
resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128).to(torch_device)
with torch.no_grad():
output_tensor = resnet_block(sample, temb)
assert output_tensor.shape == (1, 32, 64, 64)
output_slice = output_tensor[0, -1, -3:, -3:]
expected_slice = torch.tensor(
[-1.9010, -0.2974, -0.8245, -1.3533, 0.8742, -0.9645, -2.0584, 1.3387, -0.4746], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_restnet_with_use_in_shortcut(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
temb = torch.randn(1, 128).to(torch_device)
resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, use_in_shortcut=True).to(torch_device)
with torch.no_grad():
output_tensor = resnet_block(sample, temb)
assert output_tensor.shape == (1, 32, 64, 64)
output_slice = output_tensor[0, -1, -3:, -3:]
expected_slice = torch.tensor(
[0.2226, -1.0791, -0.1629, 0.3659, -0.2889, -1.2376, 0.0582, 0.9206, 0.0044], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_resnet_up(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
temb = torch.randn(1, 128).to(torch_device)
resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, up=True).to(torch_device)
with torch.no_grad():
output_tensor = resnet_block(sample, temb)
assert output_tensor.shape == (1, 32, 128, 128)
output_slice = output_tensor[0, -1, -3:, -3:]
expected_slice = torch.tensor(
[1.2130, -0.8753, -0.9027, 1.5783, -0.5362, -0.5001, 1.0726, -0.7732, -0.4182], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_resnet_down(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
temb = torch.randn(1, 128).to(torch_device)
resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, down=True).to(torch_device)
with torch.no_grad():
output_tensor = resnet_block(sample, temb)
assert output_tensor.shape == (1, 32, 32, 32)
output_slice = output_tensor[0, -1, -3:, -3:]
expected_slice = torch.tensor(
[-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_restnet_with_kernel_fir(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
temb = torch.randn(1, 128).to(torch_device)
resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="fir", down=True).to(torch_device)
with torch.no_grad():
output_tensor = resnet_block(sample, temb)
assert output_tensor.shape == (1, 32, 32, 32)
output_slice = output_tensor[0, -1, -3:, -3:]
expected_slice = torch.tensor(
[-0.0934, -0.5729, 0.0909, -0.2710, -0.5044, 0.0243, -0.0665, -0.5267, -0.3136], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_restnet_with_kernel_sde_vp(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
temb = torch.randn(1, 128).to(torch_device)
resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="sde_vp", down=True).to(torch_device)
with torch.no_grad():
output_tensor = resnet_block(sample, temb)
assert output_tensor.shape == (1, 32, 32, 32)
output_slice = output_tensor[0, -1, -3:, -3:]
expected_slice = torch.tensor(
[-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class AttentionBlockTests(unittest.TestCase): class AttentionBlockTests(unittest.TestCase):
@unittest.skipIf( @unittest.skipIf(
torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039" torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039"