diffusers/tests/test_pipelines.py

1018 lines
38 KiB
Python
Raw Normal View History

# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import tempfile
import unittest
import numpy as np
import torch
import PIL
2022-08-30 10:43:42 -06:00
from datasets import load_dataset
from diffusers import (
AutoencoderKL,
DDIMPipeline,
DDIMScheduler,
DDPMPipeline,
DDPMScheduler,
KarrasVePipeline,
KarrasVeScheduler,
LDMPipeline,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PNDMPipeline,
PNDMScheduler,
ScoreSdeVePipeline,
ScoreSdeVeScheduler,
2022-08-30 10:43:42 -06:00
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
UNet2DConditionModel,
UNet2DModel,
VQModel,
)
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
torch.backends.cuda.matmul.allow_tf32 = False
def test_progress_bar(capsys):
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, scheduler).to(torch_device)
ddpm(output_type="numpy")["sample"]
captured = capsys.readouterr()
assert "10/10" in captured.err, "Progress bar has to be displayed"
ddpm.set_progress_bar_config(disable=True)
ddpm(output_type="numpy")["sample"]
captured = capsys.readouterr()
assert captured.err == "", "Progress bar should be disabled"
class PipelineFastTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@property
def dummy_image(self):
batch_size = 1
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
return image
@property
def dummy_uncond_unet(self):
torch.manual_seed(0)
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
return model
@property
def dummy_cond_unet(self):
torch.manual_seed(0)
model = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
return model
@property
def dummy_vq_model(self):
torch.manual_seed(0)
model = VQModel(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=3,
)
return model
@property
def dummy_vae(self):
torch.manual_seed(0)
model = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
return model
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
chunk_size_feed_forward=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModel(config)
@property
def dummy_safety_checker(self):
def check(images, *args, **kwargs):
return images, False
return check
@property
def dummy_extractor(self):
def extract(*args, **kwargs):
class Out:
def __init__(self):
self.pixel_values = torch.ones([0])
def to(self, device):
self.pixel_values.to(device)
return self
return Out()
return extract
def test_ddim(self):
unet = self.dummy_uncond_unet
scheduler = DDIMScheduler(tensor_format="pt")
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array(
[1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_pndm_cifar10(self):
unet = self.dummy_uncond_unet
scheduler = PNDMScheduler(tensor_format="pt")
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
pndm.to(torch_device)
pndm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pndm(generator=generator, num_inference_steps=20, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_ldm_text2img(self):
unet = self.dummy_cond_unet
scheduler = DDIMScheduler(tensor_format="pt")
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[
"sample"
]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
image = output["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_pndm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
image = output["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
image = output["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_score_sde_ve_pipeline(self):
unet = self.dummy_uncond_unet
scheduler = ScoreSdeVeScheduler(tensor_format="pt")
sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
sde_ve.to(torch_device)
sde_ve.set_progress_bar_config(disable=None)
torch.manual_seed(0)
image = sde_ve(num_inference_steps=2, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_ldm_uncond(self):
unet = self.dummy_uncond_unet
scheduler = DDIMScheduler(tensor_format="pt")
vae = self.dummy_vq_model
ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_karras_ve_pipeline(self):
unet = self.dummy_uncond_unet
scheduler = KarrasVeScheduler(tensor_format="pt")
pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pipe(num_inference_steps=2, generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_img2img(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
init_image = self.dummy_image.to(device)
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImg2ImgPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
)
image = output["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_img2img_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
init_image = self.dummy_image.to(device)
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImg2ImgPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
)
image = output["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image.to(device).permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
mask_image=mask_image,
)
image = output["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
class PipelineTesterMixin(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_from_pretrained_save_pretrained(self):
# 1. Load models
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
schedular = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, schedular)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
new_ddpm.to(torch_device)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0)
new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_from_pretrained_hub(self):
model_path = "google/ddpm-cifar10-32"
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm_from_hub.to(torch_device)
ddpm_from_hub.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_from_pretrained_hub_pass_model(self):
model_path = "google/ddpm-cifar10-32"
scheduler = DDPMScheduler(num_train_timesteps=10)
# pass unet into DiffusionPipeline
unet = UNet2DModel.from_pretrained(model_path)
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
ddpm_from_hub_custom_model.to(torch_device)
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm_from_hub.to(torch_device)
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_output_format(self):
model_path = "google/ddpm-cifar10-32"
pipe = DDIMPipeline.from_pretrained(model_path)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
images = pipe(generator=generator, output_type="numpy")["sample"]
assert images.shape == (1, 32, 32, 3)
assert isinstance(images, np.ndarray)
images = pipe(generator=generator, output_type="pil")["sample"]
assert isinstance(images, list)
assert len(images) == 1
assert isinstance(images[0], PIL.Image.Image)
# use PIL by default
images = pipe(generator=generator)["sample"]
assert isinstance(images, list)
assert isinstance(images[0], PIL.Image.Image)
@slow
def test_ddpm_cifar10(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDPMScheduler.from_config(model_id)
scheduler = scheduler.set_format("pt")
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_ddim_lsun(self):
model_id = "google/ddpm-ema-bedroom-256"
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler.from_config(model_id)
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_ddim_cifar10(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler(tensor_format="pt")
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
ddim.to(torch_device)
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_pndm_cifar10(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
scheduler = PNDMScheduler(tensor_format="pt")
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
pndm.to(torch_device)
pndm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pndm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_ldm_text2img(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
"sample"
]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_ldm_text2img_fast(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion(self):
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", use_auth_token=True)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast("cuda"):
output = sd_pipe(
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
)
image = output["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_fast_ddim(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", use_auth_token=True)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
sd_pipe.scheduler = scheduler
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast("cuda"):
output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
image = output["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.8354, 0.83, 0.866, 0.838, 0.8315, 0.867, 0.836, 0.8584, 0.869])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@slow
def test_score_sde_ve_pipeline(self):
model_id = "google/ncsnpp-church-256"
model = UNet2DModel.from_pretrained(model_id)
scheduler = ScoreSdeVeScheduler.from_config(model_id)
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
sde_ve.to(torch_device)
sde_ve.set_progress_bar_config(disable=None)
torch.manual_seed(0)
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_ldm_uncond(self):
ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_ddpm_ddim_equality(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
ddpm_scheduler = DDPMScheduler(tensor_format="pt")
ddim_scheduler = DDIMScheduler(tensor_format="pt")
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
ddim.to(torch_device)
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]
generator = torch.manual_seed(0)
ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]
# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_image - ddim_image).max() < 1e-1
@unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
def test_ddpm_ddim_equality_batched(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
ddpm_scheduler = DDPMScheduler(tensor_format="pt")
ddim_scheduler = DDIMScheduler(tensor_format="pt")
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
ddim.to(torch_device)
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]
generator = torch.manual_seed(0)
ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
"sample"
]
# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
@slow
def test_karras_ve_pipeline(self):
model_id = "google/ncsnpp-celebahq-256"
model = UNet2DModel.from_pretrained(model_id)
scheduler = KarrasVeScheduler(tensor_format="pt")
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_lms_stable_diffusion_pipeline(self):
model_id = "CompVis/stable-diffusion-v1-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).to(torch_device)
pipe.set_progress_bar_config(disable=None)
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
pipe.scheduler = scheduler
prompt = "a photograph of an astronaut riding a horse"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")[
"sample"
]
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
2022-08-30 10:43:42 -06:00
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_img2img_pipeline(self):
ds = load_dataset("hf-internal-testing/diffusers-images", split="train")
init_image = ds[2]["image"].resize((768, 512))
2022-08-30 10:43:42 -06:00
output_image = ds[0]["image"].resize((768, 512))
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, use_auth_token=True)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
2022-08-30 10:43:42 -06:00
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)[
"sample"
][0]
expected_array = np.array(output_image)
sampled_array = np.array(image)
assert sampled_array.shape == (512, 768, 3)
assert np.max(np.abs(sampled_array - expected_array)) < 1e-4
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_img2img_pipeline_k_lms(self):
2022-08-30 10:43:42 -06:00
ds = load_dataset("hf-internal-testing/diffusers-images", split="train")
init_image = ds[2]["image"].resize((768, 512))
output_image = ds[1]["image"].resize((768, 512))
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, scheduler=lms, use_auth_token=True)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)[
"sample"
][0]
expected_array = np.array(output_image)
sampled_array = np.array(image)
assert sampled_array.shape == (512, 768, 3)
assert np.max(np.abs(sampled_array - expected_array)) < 1e-4
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_in_paint_pipeline(self):
ds = load_dataset("hf-internal-testing/diffusers-images", split="train")
init_image = ds[3]["image"].resize((768, 512))
mask_image = ds[4]["image"].resize((768, 512))
output_image = ds[5]["image"].resize((768, 512))
2022-08-30 10:43:42 -06:00
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, use_auth_token=True)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
2022-08-30 10:43:42 -06:00
prompt = "A red cat sitting on a parking bench"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe(
prompt=prompt,
init_image=init_image,
mask_image=mask_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
)["sample"][0]
expected_array = np.array(output_image)
sampled_array = np.array(image)
assert sampled_array.shape == (512, 768, 3)
assert np.max(np.abs(sampled_array - expected_array)) < 1e-3