[Tests] Fix spatial transformer tests on GPU (#531)
This commit is contained in:
parent
c1796efd5f
commit
761f0297b0
|
@ -240,7 +240,9 @@ class AttentionBlockTests(unittest.TestCase):
|
|||
assert attention_scores.shape == (1, 32, 64, 64)
|
||||
output_slice = attention_scores[0, -1, -3:, -3:]
|
||||
|
||||
expected_slice = torch.tensor([-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427])
|
||||
expected_slice = torch.tensor(
|
||||
[-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427], device=torch_device
|
||||
)
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
|
||||
|
@ -264,7 +266,9 @@ class SpatialTransformerTests(unittest.TestCase):
|
|||
assert attention_scores.shape == (1, 32, 64, 64)
|
||||
output_slice = attention_scores[0, -1, -3:, -3:]
|
||||
|
||||
expected_slice = torch.tensor([-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201])
|
||||
expected_slice = torch.tensor(
|
||||
[-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201], device=torch_device
|
||||
)
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_spatial_transformer_context_dim(self):
|
||||
|
@ -287,7 +291,9 @@ class SpatialTransformerTests(unittest.TestCase):
|
|||
assert attention_scores.shape == (1, 64, 64, 64)
|
||||
output_slice = attention_scores[0, -1, -3:, -3:]
|
||||
|
||||
expected_slice = torch.tensor([-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471])
|
||||
expected_slice = torch.tensor(
|
||||
[-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471], device=torch_device
|
||||
)
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_spatial_transformer_dropout(self):
|
||||
|
@ -313,5 +319,7 @@ class SpatialTransformerTests(unittest.TestCase):
|
|||
assert attention_scores.shape == (1, 32, 64, 64)
|
||||
output_slice = attention_scores[0, -1, -3:, -3:]
|
||||
|
||||
expected_slice = torch.tensor([-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091])
|
||||
expected_slice = torch.tensor(
|
||||
[-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device
|
||||
)
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
|
Loading…
Reference in New Issue