Test ResnetBlock2D (#1850)

* test resnet block
* fix code format required by isort
* add torch device
* nit

This commit is contained in:
parent cb8a3dbe34
commit 9e17983d9f
@@ -22,7 +22,7 @@ from torch import nn
 from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock
 from diffusers.models.embeddings import get_timestep_embedding
-from diffusers.models.resnet import Downsample2D, Upsample2D
+from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from diffusers.models.transformer_2d import Transformer2DModel
 from diffusers.utils import torch_device
@@ -222,6 +222,98 @@ class Downsample2DBlockTests(unittest.TestCase):
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)


+class ResnetBlock2DTests(unittest.TestCase):
+    def test_resnet_default(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 64, 64)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [-1.9010, -0.2974, -0.8245, -1.3533, 0.8742, -0.9645, -2.0584, 1.3387, -0.4746], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_resnet_with_use_in_shortcut(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, use_in_shortcut=True).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 64, 64)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [0.2226, -1.0791, -0.1629, 0.3659, -0.2889, -1.2376, 0.0582, 0.9206, 0.0044], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_resnet_up(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, up=True).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 128, 128)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [1.2130, -0.8753, -0.9027, 1.5783, -0.5362, -0.5001, 1.0726, -0.7732, -0.4182], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_resnet_down(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, down=True).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 32, 32)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_resnet_with_kernel_fir(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="fir", down=True).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 32, 32)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [-0.0934, -0.5729, 0.0909, -0.2710, -0.5044, 0.0243, -0.0665, -0.5267, -0.3136], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_resnet_with_kernel_sde_vp(self):
+        torch.manual_seed(0)
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        temb = torch.randn(1, 128).to(torch_device)
+        resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128, kernel="sde_vp", down=True).to(torch_device)
+        with torch.no_grad():
+            output_tensor = resnet_block(sample, temb)
+
+        assert output_tensor.shape == (1, 32, 32, 32)
+        output_slice = output_tensor[0, -1, -3:, -3:]
+        expected_slice = torch.tensor(
+            [-0.3002, -0.7135, 0.1359, 0.0561, -0.7935, 0.0113, -0.1766, -0.6714, -0.0436], device=torch_device
+        )
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
 class AttentionBlockTests(unittest.TestCase):
     @unittest.skipIf(
         torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039"
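Note: the expected_slice values pinned above can be regenerated with a short standalone script. The following is a sketch of the same slice-checking pattern (not part of the commit), assuming a local diffusers install that provides ResnetBlock2D and torch_device exactly as imported in the diff:

# Sketch (not part of the commit): regenerate the deterministic corner slice
# that the tests pin as expected_slice.
import torch

from diffusers.models.resnet import ResnetBlock2D
from diffusers.utils import torch_device

torch.manual_seed(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
temb = torch.randn(1, 128).to(torch_device)

resnet_block = ResnetBlock2D(in_channels=32, temb_channels=128).to(torch_device)
with torch.no_grad():
    output_tensor = resnet_block(sample, temb)

# Last channel, bottom-right 3x3 corner: the nine values asserted in
# test_resnet_default.
print(output_tensor[0, -1, -3:, -3:].flatten())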
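The up/down variants change only the spatial resolution, as the shape assertions show: up=True doubles height and width (64 -> 128) and down=True halves them (64 -> 32), while the channel count stays at in_channels. A minimal sketch of that shape behaviour, continuing from the snippet above:

# Shape behaviour asserted in test_resnet_up / test_resnet_down, reusing the
# sample and temb tensors defined in the previous snippet.
up_block = ResnetBlock2D(in_channels=32, temb_channels=128, up=True).to(torch_device)
down_block = ResnetBlock2D(in_channels=32, temb_channels=128, down=True).to(torch_device)

with torch.no_grad():
    print(up_block(sample, temb).shape)    # torch.Size([1, 32, 128, 128])
    print(down_block(sample, temb).shape)  # torch.Size([1, 32, 32, 32])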