add RL test, remove conds from RL model input
This commit is contained in:
parent
a2b72faff7
commit
3a5c87055c
|
@ -195,7 +195,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
|||
nn.Conv1d(dim, transition_dim, 1),
|
||||
)
|
||||
|
||||
def forward(self, x, cond, time):
|
||||
def forward(self, x, time):
|
||||
"""
|
||||
x : [ batch x horizon x transition ]
|
||||
"""
|
||||
|
|
|
@ -40,6 +40,7 @@ from diffusers import (
|
|||
ScoreSdeVeScheduler,
|
||||
ScoreSdeVpPipeline,
|
||||
ScoreSdeVpScheduler,
|
||||
TemporalUNet,
|
||||
UNetGradTTSModel,
|
||||
UNetLDMModel,
|
||||
UNetModel,
|
||||
|
@ -606,6 +607,46 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = TemporalUNet
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = TemporalUNet.from_pretrained(
|
||||
"fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True
|
||||
)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
def test_output_pretrained(self):
|
||||
model = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128")
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_features = model.transition_dim
|
||||
seq_len = 16
|
||||
noise = torch.randn((1, seq_len, num_features))
|
||||
time_step = torch.full((num_features,), 0)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step)
|
||||
|
||||
output_slice = output[0, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580,
|
||||
-0.0584])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = NCSNpp
|
||||
|
||||
|
|
Loading…
Reference in New Issue