Merge branch 'main' of github.com:huggingface/diffusers
commit 62c2c547db
@@ -31,6 +31,7 @@ from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    GLIDESuperResUNetModel,
    GLIDETextToImageUNetModel,
    LatentDiffusion,
    PNDMScheduler,
    UNetGradTTSModel,
@@ -261,8 +262,6 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
        sizes = (32, 32)
        low_res_size = (4, 4)

        torch_device = "cpu"

        noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device)
        low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
@@ -343,6 +342,100 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = GLIDETextToImageUNetModel

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)
        transformer_dim = 32
        seq_len = 16

        noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
        emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)

        return {"x": noise, "timesteps": time_step, "transformer_out": emb}

    @property
    def get_input_shape(self):
        return (3, 32, 32)

    @property
    def get_output_shape(self):
        return (6, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "attention_resolutions": (2,),
            "channel_mult": (1, 2),
            "in_channels": 3,
            "out_channels": 6,
            "model_channels": 32,
            "num_head_channels": 8,
            "num_heads_upsample": 1,
            "num_res_blocks": 2,
            "resblock_updown": True,
            "resolution": 32,
            "use_scale_shift_norm": True,
            "transformer_dim": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_output(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        output, _ = torch.split(output, 3, dim=1)

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["x"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_from_pretrained_hub(self):
        model, loading_info = GLIDETextToImageUNetModel.from_pretrained(
            "fusing/unet-glide-text2im-dummy", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = GLIDETextToImageUNetModel.from_pretrained("fusing/unet-glide-text2im-dummy")

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        noise = torch.randn((1, model.config.in_channels, model.config.resolution, model.config.resolution)).to(
            torch_device
        )
        emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)

        with torch.no_grad():
            output = model(noise, time_step, emb)

        output, _ = torch.split(output, 3, dim=1)
        output_slice = output[0, -1, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([ 2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845])
        # fmt: on
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNetLDMModel
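For context, a minimal sketch of the forward pass the new GLIDETextToImageUNetModelTests exercise. It assumes the diffusers version at this commit (where GLIDETextToImageUNetModel is importable), the "fusing/unet-glide-text2im-dummy" Hub checkpoint used by the tests, and the call signature model(x, timesteps, transformer_out) taken directly from the diff:

import torch
from diffusers import GLIDETextToImageUNetModel

# Load the small dummy checkpoint the tests pull from the Hub (assumption: it is public).
model = GLIDETextToImageUNetModel.from_pretrained("fusing/unet-glide-text2im-dummy")
model.eval()

# Shapes mirror GLIDETextToImageUNetModelTests.dummy_input: a 3x32x32 noisy image,
# a (seq_len=16, transformer_dim=32) text-transformer output, and one timestep per sample.
noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
emb = torch.randn(1, 16, model.config.transformer_dim)
timesteps = torch.tensor([10] * noise.shape[0])

with torch.no_grad():
    output = model(noise, timesteps, emb)

# Split the 6 output channels the same way the tests do; in GLIDE-style models these are
# typically the noise prediction and the learned variance.
eps, variance = torch.split(output, 3, dim=1)
print(eps.shape, variance.shape)  # torch.Size([1, 3, 32, 32]) for both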