diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 9c0c7713..56264128 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -195,7 +195,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): nn.Conv1d(dim, transition_dim, 1), ) - def forward(self, x, cond, time): + def forward(self, x, time): """ x : [ batch x horizon x transition ] """ diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index be7f61da..70e3ec3b 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -40,6 +40,7 @@ from diffusers import ( ScoreSdeVeScheduler, ScoreSdeVpPipeline, ScoreSdeVpScheduler, + TemporalUNet, UNetGradTTSModel, UNetLDMModel, UNetModel, @@ -606,6 +607,46 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) +class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase): + model_class = TemporalUNet + + def test_from_pretrained_hub(self): + model, loading_info = TemporalUNet.from_pretrained( + "fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = model.transition_dim + seq_len = 16 + noise = torch.randn((1, seq_len, num_features)) + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = model(noise, time_step) + + output_slice = output[0, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, + -0.0584]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): model_class = NCSNpp