[Tests] Test attention.py (#368)
* add test for AttentionBlock, SpatialTransformer
* add context_dim, handle device
* removed dropout test
* fixes, add dropout test
parent 37c9d789aa
commit f73ca908e5
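All of the new tests follow the same snapshot pattern: seed the RNGs, run the block under torch.no_grad(), then compare a small corner slice of the output against hardcoded reference values. A minimal sketch of that pattern as a standalone helper (the name assert_matches_slice is hypothetical, not part of this diff):

    import torch

    def assert_matches_slice(output: torch.Tensor, expected_slice: torch.Tensor, atol: float = 1e-3) -> None:
        # hypothetical helper mirroring the assertions repeated in each test below:
        # check the last channel's bottom-right 3x3 corner against known values
        output_slice = output[0, -1, -3:, -3:]
        assert torch.allclose(output_slice.flatten(), expected_slice, atol=atol)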
@@ -19,8 +19,10 @@ import unittest
 import numpy as np
 import torch
 
+from diffusers.models.attention import AttentionBlock, SpatialTransformer
 from diffusers.models.embeddings import get_timestep_embedding
 from diffusers.models.resnet import Downsample2D, Upsample2D
+from diffusers.testing_utils import torch_device
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
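The second added import, torch_device, is what lets the tests run on GPU when one is present ("handle device" in the commit message). Its exact definition lives in diffusers.testing_utils; a minimal sketch, assuming the usual CUDA-or-CPU fallback:

    import torch

    # assumption: torch_device resolves to "cuda" when a GPU is visible, else "cpu"
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"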
@@ -216,3 +218,100 @@ class Downsample2DBlockTests(unittest.TestCase):
         output_slice = downsampled[0, -1, -3:, -3:]
         expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522])
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class AttentionBlockTests(unittest.TestCase):
+    def test_attention_block_default(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        attentionBlock = AttentionBlock(
+            channels=32,
+            num_head_channels=1,
+            rescale_output_factor=1.0,
+            eps=1e-6,
+            num_groups=32,
+        ).to(torch_device)
+        with torch.no_grad():
+            attention_scores = attentionBlock(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+
+class SpatialTransformerTests(unittest.TestCase):
+    def test_spatial_transformer_default(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        spatial_transformer_block = SpatialTransformer(
+            in_channels=32,
+            n_heads=1,
+            d_head=32,
+            dropout=0.0,
+            context_dim=None,
+        ).to(torch_device)
+        with torch.no_grad():
+            attention_scores = spatial_transformer_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_spatial_transformer_context_dim(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 64, 64, 64).to(torch_device)
+        spatial_transformer_block = SpatialTransformer(
+            in_channels=64,
+            n_heads=2,
+            d_head=32,
+            dropout=0.0,
+            context_dim=64,
+        ).to(torch_device)
+        with torch.no_grad():
+            context = torch.randn(1, 4, 64).to(torch_device)
+            attention_scores = spatial_transformer_block(sample, context)
+
+        assert attention_scores.shape == (1, 64, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
+
+    def test_spatial_transformer_dropout(self):
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        sample = torch.randn(1, 32, 64, 64).to(torch_device)
+        spatial_transformer_block = (
+            SpatialTransformer(
+                in_channels=32,
+                n_heads=2,
+                d_head=16,
+                dropout=0.3,
+                context_dim=None,
+            )
+            .to(torch_device)
+            .eval()
+        )
+        with torch.no_grad():
+            attention_scores = spatial_transformer_block(sample)
+
+        assert attention_scores.shape == (1, 32, 64, 64)
+        output_slice = attention_scores[0, -1, -3:, -3:]
+
+        expected_slice = torch.tensor([-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091])
+        assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
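Note that test_spatial_transformer_dropout can pin exact output values even with dropout=0.3 because .eval() switches the module to inference mode, where dropout is a no-op. For reference, a minimal sketch of exercising the cross-attention path outside the test harness, reusing the constructor arguments from test_spatial_transformer_context_dim above (CPU-only here for simplicity):

    import torch
    from diffusers.models.attention import SpatialTransformer

    # same configuration as the context_dim test; context carries the
    # cross-attention conditioning with shape (batch, sequence, context_dim)
    block = SpatialTransformer(in_channels=64, n_heads=2, d_head=32, dropout=0.0, context_dim=64).eval()
    sample = torch.randn(1, 64, 64, 64)
    context = torch.randn(1, 4, 64)
    with torch.no_grad():
        out = block(sample, context)
    assert out.shape == sample.shape  # spatial dimensions are preserved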