From 9e17983d9f8bd1865f6426a55e7a048c4c50ce84 Mon Sep 17 00:00:00 2001
From: Erin <14718778+hchings@users.noreply.github.com>
Date: Wed, 4 Jan 2023 13:57:32 -0800
Subject: [PATCH] Test ResnetBlock2D (#1850)

* test resnet block

* fix code format required by isort

* add torch device

* nit
---
 tests/test_layers_utils.py | 94 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 93 insertions(+), 1 deletion(-)
 mode change 100755 => 100644 tests/test_layers_utils.py

diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py
old mode 100755
new mode 100644
index 43bf05f6..344c4721
--- a/tests/test_layers_utils.py
+++ b/tests/test_layers_utils.py
@@ -22,7 +22,7 @@ from torch import nn
 
 from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock
 from diffusers.models.embeddings import get_timestep_embedding
-from diffusers.models.resnet import Downsample2D, Upsample2D
+from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from diffusers.models.transformer_2d import Transformer2DModel
 from diffusers.utils import torch_device
 
@@ -222,6 +222,98 @@ class Downsample2DBlockTests(unittest.TestCase):
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
 
 
+class ResnetBlock2DTests(unittest.TestCase):
+    def test_resnet_default(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 64, 64)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [-1.9010, -0.2974, -0.8245, -1.3533, 0.8742, -0.9645, -2.0584, 1.3387, -0.4746], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_resnet_with_use_in_shortcut(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, use_in_shortcut=True).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 64, 64)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [0.2226, -1.0791, -0.1629, 0.3659, -0.2889, -1.2376, 0.0582, 0.9206, 0.0044], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_resnet_up(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, up=True).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 128, 128)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [1.2130, -0.8753, -0.9027, 1.5783, -0.5362, -0.5001, 1.0726, -0.7732, -0.4182], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_resnet_down(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, down=True).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 32, 32)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_resnet_with_kernel_fir(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="fir", down=True).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 32, 32)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [-0.0934, -0.5729, 0.0909, -0.2710, -0.5044, 0.0243, -0.0665, -0.5267, -0.3136], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_resnet_with_kernel_sde_vp(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="sde_vp", down=True).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 32, 32)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
 class AttentionBlockTests(unittest.TestCase):
     @unittest.skipIf(
         torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039"
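
For a quick manual check, the new cases can be selected with pytest tests/test_layers_utils.py -k ResnetBlock2D. Below is a minimal standalone sketch of what they exercise, assuming the same ResnetBlock2D and torch_device imports the patch uses; it is an illustration, not part of the patch:

import torch

from diffusers.models.resnet import ResnetBlock2D
from diffusers.utils import torch_device

# NCHW input plus a timestep embedding, mirroring the test fixtures.
sample = torch.randn(1, 32, 64, 64).to(torch_device)
temb = torch.randn(1, 128).to(torch_device)

# down=True halves the spatial dimensions, as test_resnet_down asserts.
block = ResnetBlock2D(in_channels=32, temb_channels=128, down=True).to(torch_device)
with torch.no_grad():
    out = block(sample, temb)

print(tuple(out.shape))  # (1, 32, 32, 32)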