This commit is contained in:
parent
c54f36f087
commit
3dbd6a8f4d
|
@ -207,6 +207,9 @@ class ResBlock(TimestepBlock):
|
|||
|
||||
self.updown = up or down
|
||||
|
||||
# if self.updown:
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
|
||||
|
|
|
@ -259,7 +259,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
# fmt: off
|
||||
expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
|
||||
class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
@ -607,7 +607,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
|
||||
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
@ -678,7 +678,7 @@ class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
|
||||
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
@ -753,7 +753,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
expected_output_slice = torch.tensor([3.1909e-07, -8.5393e-08, 4.8460e-07, -4.5550e-07, -1.3205e-06, -6.3475e-07, 9.7837e-07, 2.9974e-07, 1.2345e-06])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
def test_output_pretrained_ve_large(self):
|
||||
model = NCSNpp.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy")
|
||||
|
@ -779,7 +779,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
expected_output_slice = torch.tensor([-8.3299e-07, -9.0431e-07, 4.0585e-08, 9.7563e-07, 1.0280e-06, 1.0133e-06, 1.4979e-06, -2.9716e-07, -6.1817e-07])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
def test_output_pretrained_vp(self):
|
||||
model = NCSNpp.from_pretrained("fusing/ddpm-cifar10-vp-dummy")
|
||||
|
@ -805,7 +805,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
expected_output_slice = torch.tensor([-3.9086e-07, -1.1001e-05, 1.8881e-06, 1.1106e-05, 1.6629e-06, 2.9820e-06, 8.4978e-06, 8.0253e-07, 1.5435e-06])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
|
||||
class VQModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
@ -878,10 +878,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462,
|
||||
-0.4218])
|
||||
expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, -0.4218])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
|
||||
class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
@ -950,10 +949,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
|
|||
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662,
|
||||
0.1750])
|
||||
expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, 0.1750])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
|
||||
class PipelineTesterMixin(unittest.TestCase):
|
||||
|
|
Loading…
Reference in New Issue