Add tests for the SDE VE and VP models
This commit is contained in:
parent
4261c3aadf
commit
c9504bba10
|
@ -888,19 +888,19 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
|
||||
self.all_modules = nn.ModuleList(modules)
|
||||
|
||||
def forward(self, x, time_cond, sigmas=None):
|
||||
def forward(self, x, timesteps, sigmas=None):
|
||||
# timestep/noise_level embedding; only for continuous training
|
||||
modules = self.all_modules
|
||||
m_idx = 0
|
||||
if self.embedding_type == "fourier":
|
||||
# Gaussian Fourier features embeddings.
|
||||
used_sigmas = time_cond
|
||||
used_sigmas = timesteps
|
||||
temb = modules[m_idx](torch.log(used_sigmas))
|
||||
m_idx += 1
|
||||
|
||||
elif self.embedding_type == "positional":
|
||||
# Sinusoidal positional embeddings.
|
||||
timesteps = time_cond
|
||||
timesteps = timesteps
|
||||
used_sigmas = sigmas
|
||||
temb = get_timestep_embedding(timesteps, self.nf)
|
||||
|
||||
|
|
|
@ -606,6 +606,133 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    """Tests for the ``NCSNpp`` model.

    Combines the shared ``ModelTesterMixin`` checks (driven by
    ``prepare_init_args_and_inputs_for_common``) with tests that load
    pretrained checkpoints and compare a slice of the output against
    hard-coded reference values.
    """

    model_class = NCSNpp

    @property
    def dummy_input(self):
        """Return keyword arguments for ``NCSNpp.forward``.

        ``x`` is a ``(4, 3, 32, 32)`` tensor produced by ``floats_tensor`` and
        ``timesteps`` is a tensor holding the value 10 once per batch element;
        both are moved to ``torch_device``.
        """
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor(batch_size * [10]).to(torch_device)

        return {"x": noise, "timesteps": time_step}

    @property
    def get_input_shape(self):
        """Per-sample input shape (channels, height, width)."""
        return (3, 32, 32)

    @property
    def get_output_shape(self):
        """Per-sample output shape (channels, height, width)."""
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)``: constructor kwargs and forward inputs."""
        init_dict = {
            "image_size": 32,
            "ch_mult": [1, 2, 2, 2],
            "nf": 32,
            "fir": True,
            "progressive": "output_skip",
            "progressive_combine": "sum",
            "progressive_input": "input_skip",
            "scale_by_sigma": True,
            "skip_rescale": True,
            "embedding_type": "fourier",
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def _check_pretrained_output_slice(self, model_id, expected_output_slice):
        """Load ``model_id`` and compare a 3x3 output slice with the reference.

        Args:
            model_id: Identifier passed to ``NCSNpp.from_pretrained``.
            expected_output_slice: 1-D tensor of nine reference values.
        """
        model = NCSNpp.from_pretrained(model_id)
        model.eval()
        model.to(torch_device)

        # Seed torch's RNGs after loading the model and before building the
        # inputs, exactly as the individual tests did before being merged here.
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        inputs = self.dummy_input

        with torch.no_grad():
            output = model(inputs["x"], inputs["timesteps"])

        # First sample, last three entries of dims 1 and 2, last index of dim 3.
        output_slice = output[0, -3:, -3:, -1].flatten().cpu()

        # NOTE(review): atol=1e-3 is roughly three orders of magnitude larger
        # than the reference values (~1e-6), so this mostly checks that the
        # output is close to zero. Consider tightening the tolerance once the
        # reference values are confirmed stable across devices.
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

    def test_from_pretrained_hub(self):
        """The hub checkpoint loads with no missing keys and yields an output."""
        model, loading_info = NCSNpp.from_pretrained("fusing/cifar10-ncsnpp-ve", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        # unittest assertion instead of a bare ``assert`` (stripped under -O),
        # matching the other checks in this test.
        self.assertIsNotNone(image, "Make sure output is not None")

    def test_output_pretrained_ve_small(self):
        """Reference check for the ``fusing/ncsnpp-cifar10-ve-dummy`` checkpoint."""
        # fmt: off
        expected_output_slice = torch.tensor([3.1909e-07, -8.5393e-08, 4.8460e-07, -4.5550e-07, -1.3205e-06, -6.3475e-07, 9.7837e-07, 2.9974e-07, 1.2345e-06])
        # fmt: on

        self._check_pretrained_output_slice("fusing/ncsnpp-cifar10-ve-dummy", expected_output_slice)

    def test_output_pretrained_ve_large(self):
        """Reference check for the ``fusing/ncsnpp-ffhq-ve-dummy`` checkpoint."""
        # fmt: off
        expected_output_slice = torch.tensor([-8.3299e-07, -9.0431e-07, 4.0585e-08, 9.7563e-07, 1.0280e-06, 1.0133e-06, 1.4979e-06, -2.9716e-07, -6.1817e-07])
        # fmt: on

        self._check_pretrained_output_slice("fusing/ncsnpp-ffhq-ve-dummy", expected_output_slice)

    def test_output_pretrained_vp(self):
        """Reference check for the ``fusing/ddpm-cifar10-vp-dummy`` checkpoint."""
        # fmt: off
        expected_output_slice = torch.tensor([-3.9086e-07, -1.1001e-05, 1.8881e-06, 1.1106e-05, 1.6629e-06, 2.9820e-06, 8.4978e-06, 8.0253e-07, 1.5435e-06])
        # fmt: on

        self._check_pretrained_output_slice("fusing/ddpm-cifar10-vp-dummy", expected_output_slice)
|
||||
|
||||
|
||||
class PipelineTesterMixin(unittest.TestCase):
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
# 1. Load models
|
||||
|
|
Loading…
Reference in New Issue