add RL test, remove conds from RL model input
This commit is contained in:
parent
a2b72faff7
commit
3a5c87055c
|
@ -195,7 +195,7 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
|
||||||
nn.Conv1d(dim, transition_dim, 1),
|
nn.Conv1d(dim, transition_dim, 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, cond, time):
|
def forward(self, x, time):
|
||||||
"""
|
"""
|
||||||
x : [ batch x horizon x transition ]
|
x : [ batch x horizon x transition ]
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -40,6 +40,7 @@ from diffusers import (
|
||||||
ScoreSdeVeScheduler,
|
ScoreSdeVeScheduler,
|
||||||
ScoreSdeVpPipeline,
|
ScoreSdeVpPipeline,
|
||||||
ScoreSdeVpScheduler,
|
ScoreSdeVpScheduler,
|
||||||
|
TemporalUNet,
|
||||||
UNetGradTTSModel,
|
UNetGradTTSModel,
|
||||||
UNetLDMModel,
|
UNetLDMModel,
|
||||||
UNetModel,
|
UNetModel,
|
||||||
|
@ -606,6 +607,46 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
|
||||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||||
|
|
||||||
|
|
||||||
|
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||||
|
model_class = TemporalUNet
|
||||||
|
|
||||||
|
def test_from_pretrained_hub(self):
|
||||||
|
model, loading_info = TemporalUNet.from_pretrained(
|
||||||
|
"fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
image = model(**self.dummy_input)
|
||||||
|
|
||||||
|
assert image is not None, "Make sure output is not None"
|
||||||
|
|
||||||
|
def test_output_pretrained(self):
|
||||||
|
model = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(0)
|
||||||
|
|
||||||
|
num_features = model.transition_dim
|
||||||
|
seq_len = 16
|
||||||
|
noise = torch.randn((1, seq_len, num_features))
|
||||||
|
time_step = torch.full((num_features,), 0)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(noise, time_step)
|
||||||
|
|
||||||
|
output_slice = output[0, -3:, -3:].flatten()
|
||||||
|
# fmt: off
|
||||||
|
expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580,
|
||||||
|
-0.0584])
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||||
|
|
||||||
|
|
||||||
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||||
model_class = NCSNpp
|
model_class = NCSNpp
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue