This commit is contained in:
Patrick von Platen 2022-06-30 14:54:31 +00:00
parent c54f36f087
commit 3dbd6a8f4d
2 changed files with 13 additions and 12 deletions

View File

@ -207,6 +207,9 @@ class ResBlock(TimestepBlock):
self.updown = up or down
# if self.updown:
# import ipdb; ipdb.set_trace()
if up:
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)

View File

@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# fmt: off
expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
@ -607,7 +607,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
@ -678,7 +678,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
@ -753,7 +753,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([3.1909e-07, -8.5393e-08, 4.8460e-07, -4.5550e-07, -1.3205e-06, -6.3475e-07, 9.7837e-07, 2.9974e-07, 1.2345e-06])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
def test_output_pretrained_ve_large(self):
model = NCSNpp.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy")
@ -779,7 +779,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-8.3299e-07, -9.0431e-07, 4.0585e-08, 9.7563e-07, 1.0280e-06, 1.0133e-06, 1.4979e-06, -2.9716e-07, -6.1817e-07])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
def test_output_pretrained_vp(self):
model = NCSNpp.from_pretrained("fusing/ddpm-cifar10-vp-dummy")
@ -805,7 +805,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-3.9086e-07, -1.1001e-05, 1.8881e-06, 1.1106e-05, 1.6629e-06, 2.9820e-06, 8.4978e-06, 8.0253e-07, 1.5435e-06])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class VQModelTests(ModelTesterMixin, unittest.TestCase):
@ -878,10 +878,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462,
-0.4218])
expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
@ -950,10 +949,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662,
0.1750])
expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class PipelineTesterMixin(unittest.TestCase):