add test for VQModel

This commit is contained in:
patil-suraj 2022-06-29 12:34:24 +02:00
parent 0b7daa6de9
commit bae04ea9d8
1 changed files with 78 additions and 0 deletions

View File

@ -22,6 +22,7 @@ import numpy as np
import torch
from diffusers import (
AutoencoderKL,
BDDMPipeline,
DDIMPipeline,
DDIMScheduler,
@ -44,6 +45,7 @@ from diffusers import (
UNetGradTTSModel,
UNetLDMModel,
UNetModel,
VQModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
@ -805,6 +807,82 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class VQModelTests(ModelTesterMixin, unittest.TestCase):
model_class = VQModel
@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
return {"x": image}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"ch": 64,
"out_ch": 3,
"num_res_blocks": 1,
"attn_resolutions": [],
"in_channels": 3,
"resolution": 32,
"z_channels": 3,
"n_embed": 256,
"embed_dim": 3,
"sane_index_shape": False,
"ch_mult": (1,),
"dropout": 0.0,
"double_z": False,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_forward_signature(self):
pass
def test_training(self):
pass
def test_from_pretrained_hub(self):
model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = VQModel.from_pretrained("fusing/vqgan-dummy")
model.eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
with torch.no_grad():
output = model(image)
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462,
-0.4218])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self):
# 1. Load models