add RL test, remove conds from RL model input

This commit is contained in:
Nathan Lambert 2022-06-27 14:48:15 -04:00
parent a2b72faff7
commit 3a5c87055c
No known key found for this signature in database
GPG Key ID: D667B1D408FF16AF
2 changed files with 42 additions and 1 deletions

View File

@ -195,7 +195,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
nn.Conv1d(dim, transition_dim, 1), nn.Conv1d(dim, transition_dim, 1),
) )
def forward(self, x, cond, time): def forward(self, x, time):
""" """
x : [ batch x horizon x transition ] x : [ batch x horizon x transition ]
""" """

View File

@ -40,6 +40,7 @@ from diffusers import (
ScoreSdeVeScheduler, ScoreSdeVeScheduler,
ScoreSdeVpPipeline, ScoreSdeVpPipeline,
ScoreSdeVpScheduler, ScoreSdeVpScheduler,
TemporalUNet,
UNetGradTTSModel, UNetGradTTSModel,
UNetLDMModel, UNetLDMModel,
UNetModel, UNetModel,
@ -606,6 +607,46 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
model_class = TemporalUNet
def test_from_pretrained_hub(self):
model, loading_info = TemporalUNet.from_pretrained(
"fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128")
model.eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
num_features = model.transition_dim
seq_len = 16
noise = torch.randn((1, seq_len, num_features))
time_step = torch.full((num_features,), 0)
with torch.no_grad():
output = model(noise, time_step)
output_slice = output[0, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580,
-0.0584])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
model_class = NCSNpp model_class = NCSNpp