add UNetLDMModelTests

This commit is contained in:
patil-suraj 2022-06-20 14:45:58 +02:00
parent 646e16fe06
commit 95a45f5b3a
1 changed files with 69 additions and 0 deletions

View File

@ -34,6 +34,7 @@ from diffusers import (
LatentDiffusion,
PNDMScheduler,
UNetModel,
UNetLDMModel,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
@ -341,6 +342,74 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetLDMModel
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"x": noise, "timesteps": time_step}
@property
def get_input_shape(self):
return (4, 32, 32)
@property
def get_output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"image_size": 32,
"in_channels": 4,
"out_channels": 4,
"model_channels": 32,
"num_res_blocks": 2,
"attention_resolutions": (16,),
"channel_mult": (1, 2),
"num_heads": 2,
"conv_resample": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_from_pretrained_hub(self):
model, loading_info = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy")
model.eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
output = model(noise, time_step)
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self):