# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
import random
import tempfile
import tracemalloc
import unittest

import numpy as np
import torch

import accelerate
import PIL
import transformers
from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    KarrasVePipeline,
    KarrasVeScheduler,
    LDMPipeline,
    LDMTextToImagePipeline,
    OnnxStableDiffusionImg2ImgPipeline,
    OnnxStableDiffusionInpaintPipeline,
    OnnxStableDiffusionPipeline,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,
    StableDiffusionPipeline,
    UNet2DConditionModel,
    UNet2DModel,
    VQModel,
    logging,
)
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
from packaging import version
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer


torch.backends.cuda.matmul.allow_tf32 = False


def test_progress_bar(capsys):
    model = UNet2DModel(
        block_out_channels=(32, 64),
        layers_per_block=2,
        sample_size=32,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    )
    scheduler = DDPMScheduler(num_train_timesteps=10)

    ddpm = DDPMPipeline(model, scheduler).to(torch_device)
    ddpm(output_type="numpy").images
    captured = capsys.readouterr()
    assert "10/10" in captured.err, "Progress bar has to be displayed"

    ddpm.set_progress_bar_config(disable=True)
    ddpm(output_type="numpy").images
    captured = capsys.readouterr()
    assert captured.err == "", "Progress bar should be disabled"


class CustomPipelineTests(unittest.TestCase):
    def test_load_custom_pipeline(self):
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
        )
        # NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub
        # under https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L24
        assert pipeline.__class__.__name__ == "CustomPipeline"

    def test_run_custom_pipeline(self):
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
        )
        images, output_str = pipeline(num_inference_steps=2, output_type="np")

        assert images[0].shape == (1, 32, 32, 3)
        # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
        assert output_str == "This is a test"

    def test_local_custom_pipeline(self):
        local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
        pipeline = DiffusionPipeline.from_pretrained(
            "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
        )
        images, output_str = pipeline(num_inference_steps=2, output_type="np")

        assert pipeline.__class__.__name__ == "CustomLocalPipeline"
        assert images[0].shape == (1, 32, 32, 3)
        # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
        assert output_str == "This is a local test"

    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_load_pipeline_from_git(self):
        clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

        feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
        clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)

        pipeline = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            custom_pipeline="clip_guided_stable_diffusion",
            clip_model=clip_model,
            feature_extractor=feature_extractor,
            torch_dtype=torch.float16,
            revision="fp16",
        )
        pipeline.enable_attention_slicing()
        pipeline = pipeline.to(torch_device)

        # NOTE that `"CLIPGuidedStableDiffusion"` is not a class that is defined in the PyPI package of the
        # library, but solely in the community examples folder on GitHub under:
        # https://github.com/huggingface/diffusers/blob/main/examples/community/clip_guided_stable_diffusion.py
        assert pipeline.__class__.__name__ == "CLIPGuidedStableDiffusion"

        image = pipeline("a prompt", num_inference_steps=2, output_type="np").images[0]
        assert image.shape == (512, 512, 3)


class PipelineFastTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def dummy_image(self):
        batch_size = 1
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
        return image

    @property
    def dummy_uncond_unet(self):
        torch.manual_seed(0)
        model = UNet2DModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        )
        return model

    @property
    def dummy_cond_unet(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_cond_unet_inpaint(self):
        torch.manual_seed(0)
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=9,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        return model

    @property
    def dummy_vq_model(self):
        torch.manual_seed(0)
        model = VQModel(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=3,
        )
        return model

    @property
    def dummy_vae(self):
        torch.manual_seed(0)
        model = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        return model
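    # NOTE: the `dummy_*` properties in this class build tiny components, each freshly
    # seeded with `torch.manual_seed(0)`, so the fast tests stay deterministic and
    # cheap enough to run on CPU.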
    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModel(config)

    @property
    def dummy_extractor(self):
        def extract(*args, **kwargs):
            class Out:
                def __init__(self):
                    self.pixel_values = torch.ones([0])

                def to(self, device):
                    self.pixel_values.to(device)
                    return self

            return Out()

        return extract

    def test_ddim(self):
        unet = self.dummy_uncond_unet
        scheduler = DDIMScheduler()

        ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        # Warmup pass when using mps (see #372)
        if torch_device == "mps":
            _ = ddpm(num_inference_steps=1)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

        generator = torch.manual_seed(0)
        image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array(
            [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
        )
        tolerance = 1e-2 if torch_device != "mps" else 3e-2
        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance

    def test_pndm_cifar10(self):
        unet = self.dummy_uncond_unet
        scheduler = PNDMScheduler()

        pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
        pndm.to(torch_device)
        pndm.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = pndm(generator=generator, num_inference_steps=20, output_type="numpy").images

        generator = torch.manual_seed(0)
        image_from_tuple = pndm(generator=generator, num_inference_steps=20, output_type="numpy", return_dict=False)[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_ldm_text2img(self):
        unet = self.dummy_cond_unet
        scheduler = DDIMScheduler()
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
        ldm.to(torch_device)
        ldm.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"

        # Warmup pass when using mps (see #372)
        if torch_device == "mps":
            generator = torch.manual_seed(0)
            _ = ldm(
                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
            ).images

        generator = torch.manual_seed(0)
        image = ldm(
            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
        ).images

        generator = torch.manual_seed(0)
        image_from_tuple = ldm(
            [prompt],
            generator=generator,
            guidance_scale=6.0,
            num_inference_steps=2,
            output_type="numpy",
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_score_sde_ve_pipeline(self):
        unet = self.dummy_uncond_unet
        scheduler = ScoreSdeVeScheduler()

        sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
        sde_ve.to(torch_device)
        sde_ve.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator).images

        generator = torch.manual_seed(0)
        image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator, return_dict=False)[
            0
        ]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_ldm_uncond(self):
        unet = self.dummy_uncond_unet
        scheduler = DDIMScheduler()
        vae = self.dummy_vq_model

        ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
        ldm.to(torch_device)
        ldm.set_progress_bar_config(disable=None)

        # Warmup pass when using mps (see #372)
        if torch_device == "mps":
            generator = torch.manual_seed(0)
            _ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images

        generator = torch.manual_seed(0)
        image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images

        generator = torch.manual_seed(0)
        image_from_tuple = ldm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_karras_ve_pipeline(self):
        unet = self.dummy_uncond_unet
        scheduler = KarrasVeScheduler()

        pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images

        generator = torch.manual_seed(0)
        image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

    def test_components(self):
        """Test that components property works correctly"""
        unet = self.dummy_cond_unet
        scheduler = PNDMScheduler(skip_prk_steps=True)
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))

        # make sure here that pndm scheduler skips prk
        inpaint = StableDiffusionInpaintPipelineLegacy(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        ).to(torch_device)
        img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
        text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
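        # `inpaint.components` hands back the exact module instances the pipeline was
        # constructed with, so the sibling pipelines above share weights instead of
        # reloading them.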
= "A painting of a squirrel eating a burger" generator = torch.Generator(device=torch_device).manual_seed(0) image_inpaint = inpaint( [prompt], generator=generator, num_inference_steps=2, output_type="np", init_image=init_image, mask_image=mask_image, ).images image_img2img = img2img( [prompt], generator=generator, num_inference_steps=2, output_type="np", init_image=init_image, ).images image_text2img = text2img( [prompt], generator=generator, num_inference_steps=2, output_type="np", ).images assert image_inpaint.shape == (1, 32, 32, 3) assert image_img2img.shape == (1, 32, 32, 3) assert image_text2img.shape == (1, 128, 128, 3) class PipelineTesterMixin(unittest.TestCase): def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() torch.cuda.empty_cache() def test_smart_download(self): model_id = "hf-internal-testing/unet-pipeline-dummy" with tempfile.TemporaryDirectory() as tmpdirname: _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True) local_repo_name = "--".join(["models"] + model_id.split("/")) snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots") snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0]) # inspect all downloaded files to make sure that everything is included assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name)) assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME)) assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME)) assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME)) assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME)) assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) # let's make sure the super large numpy file: # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy # is not downloaded, but all the expected ones assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy")) def test_warning_unused_kwargs(self): model_id = "hf-internal-testing/unet-pipeline-dummy" logger = logging.get_logger("diffusers.pipeline_utils") with tempfile.TemporaryDirectory() as tmpdirname: with CaptureLogger(logger) as cap_logger: DiffusionPipeline.from_pretrained(model_id, not_used=True, cache_dir=tmpdirname, force_download=True) assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n" def test_from_pretrained_save_pretrained(self): # 1. 
Load models model = UNet2DModel( block_out_channels=(32, 64), layers_per_block=2, sample_size=32, in_channels=3, out_channels=3, down_block_types=("DownBlock2D", "AttnDownBlock2D"), up_block_types=("AttnUpBlock2D", "UpBlock2D"), ) schedular = DDPMScheduler(num_train_timesteps=10) ddpm = DDPMPipeline(model, schedular) ddpm.to(torch_device) ddpm.set_progress_bar_config(disable=None) with tempfile.TemporaryDirectory() as tmpdirname: ddpm.save_pretrained(tmpdirname) new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) new_ddpm.to(torch_device) generator = torch.manual_seed(0) image = ddpm(generator=generator, output_type="numpy").images generator = generator.manual_seed(0) new_image = new_ddpm(generator=generator, output_type="numpy").images assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" @slow def test_from_pretrained_hub(self): model_path = "google/ddpm-cifar10-32" scheduler = DDPMScheduler(num_train_timesteps=10) ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler) ddpm.to(torch_device) ddpm.set_progress_bar_config(disable=None) ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler) ddpm_from_hub.to(torch_device) ddpm_from_hub.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) image = ddpm(generator=generator, output_type="numpy").images generator = generator.manual_seed(0) new_image = ddpm_from_hub(generator=generator, output_type="numpy").images assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" @slow def test_from_pretrained_hub_pass_model(self): model_path = "google/ddpm-cifar10-32" scheduler = DDPMScheduler(num_train_timesteps=10) # pass unet into DiffusionPipeline unet = UNet2DModel.from_pretrained(model_path) ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler) ddpm_from_hub_custom_model.to(torch_device) ddpm_from_hub_custom_model.set_progress_bar_config(disable=None) ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler) ddpm_from_hub.to(torch_device) ddpm_from_hub_custom_model.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images generator = generator.manual_seed(0) new_image = ddpm_from_hub(generator=generator, output_type="numpy").images assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" @slow def test_output_format(self): model_path = "google/ddpm-cifar10-32" pipe = DDIMPipeline.from_pretrained(model_path) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) images = pipe(generator=generator, output_type="numpy").images assert images.shape == (1, 32, 32, 3) assert isinstance(images, np.ndarray) images = pipe(generator=generator, output_type="pil").images assert isinstance(images, list) assert len(images) == 1 assert isinstance(images[0], PIL.Image.Image) # use PIL by default images = pipe(generator=generator).images assert isinstance(images, list) assert isinstance(images[0], PIL.Image.Image) @slow def test_ddpm_cifar10(self): model_id = "google/ddpm-cifar10-32" unet = UNet2DModel.from_pretrained(model_id) scheduler = DDPMScheduler.from_config(model_id) ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) ddpm.to(torch_device) ddpm.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) image = ddpm(generator=generator, output_type="numpy").images image_slice = 
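    # Most of the tests below load full pretrained checkpoints from the Hub and
    # compare a deterministic 3x3 corner crop of the generated image against
    # recorded reference values.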
    @slow
    def test_ddpm_cifar10(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDPMScheduler.from_config(model_id)

        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ddim_lsun(self):
        model_id = "google/ddpm-ema-bedroom-256"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDIMScheduler.from_config(model_id)

        ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = ddpm(generator=generator, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ddim_cifar10(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = DDIMScheduler()

        ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
        ddim.to(torch_device)
        ddim.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = ddim(generator=generator, eta=0.0, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_pndm_cifar10(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        scheduler = PNDMScheduler()

        pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
        pndm.to(torch_device)
        pndm.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = pndm(generator=generator, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ldm_text2img(self):
        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
        ldm.to(torch_device)
        ldm.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = ldm(
            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
        ).images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ldm_text2img_fast(self):
        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
        ldm.to(torch_device)
        ldm.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_score_sde_ve_pipeline(self):
        model_id = "google/ncsnpp-church-256"
        model = UNet2DModel.from_pretrained(model_id)

        scheduler = ScoreSdeVeScheduler.from_config(model_id)

        sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
        sde_ve.to(torch_device)
        sde_ve.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = sde_ve(num_inference_steps=10, output_type="numpy", generator=generator).images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ldm_uncond(self):
        ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
        ldm.to(torch_device)
        ldm.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = ldm(generator=generator, num_inference_steps=5, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
    def test_ddpm_ddim_equality(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm_scheduler = DDPMScheduler()
        ddim_scheduler = DDIMScheduler()

        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
        ddim.to(torch_device)
        ddim.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        ddpm_image = ddpm(generator=generator, output_type="numpy").images

        generator = torch.manual_seed(0)
        ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_image - ddim_image).max() < 1e-1

    @unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
    def test_ddpm_ddim_equality_batched(self):
        model_id = "google/ddpm-cifar10-32"

        unet = UNet2DModel.from_pretrained(model_id)
        ddpm_scheduler = DDPMScheduler()
        ddim_scheduler = DDIMScheduler()

        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
        ddim.to(torch_device)
        ddim.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images

        generator = torch.manual_seed(0)
        ddim_images = ddim(
            batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy"
        ).images

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_images - ddim_images).max() < 1e-1

    @slow
    def test_karras_ve_pipeline(self):
        model_id = "google/ncsnpp-celebahq-256"
        model = UNet2DModel.from_pretrained(model_id)
        scheduler = KarrasVeScheduler()

        pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
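    # The ONNX tests below seed NumPy's global RNG (`np.random.seed(0)`) instead of
    # passing a torch generator, since the ONNX pipelines at this revision sample
    # their latents via NumPy.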
    @slow
    def test_stable_diffusion_onnx(self):
        sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
        )

        prompt = "A painting of a squirrel eating a burger"
        np.random.seed(0)
        output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=5, output_type="np")
        image = output.images

        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.3602, 0.3688, 0.3652, 0.3895, 0.3782, 0.3747, 0.3927, 0.4241, 0.4327])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    @slow
    def test_stable_diffusion_img2img_onnx(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/img2img/sketch-mountains-input.jpg"
        )
        init_image = init_image.resize((768, 512))

        pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
        )
        pipe.set_progress_bar_config(disable=None)

        prompt = "A fantasy landscape, trending on artstation"

        np.random.seed(0)
        output = pipe(
            prompt=prompt,
            init_image=init_image,
            strength=0.75,
            guidance_scale=7.5,
            num_inference_steps=8,
            output_type="np",
        )
        images = output.images
        image_slice = images[0, 255:258, 383:386, -1]

        assert images.shape == (1, 512, 768, 3)
        expected_slice = np.array([0.4830, 0.5242, 0.5603, 0.5016, 0.5131, 0.5111, 0.4928, 0.5025, 0.5055])
        # TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2

    @slow
    def test_stable_diffusion_inpaint_onnx(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
        )

        pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", revision="onnx", provider="CPUExecutionProvider"
        )
        pipe.set_progress_bar_config(disable=None)

        prompt = "A red cat sitting on a park bench"

        np.random.seed(0)
        output = pipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            guidance_scale=7.5,
            num_inference_steps=8,
            output_type="np",
        )
        images = output.images
        image_slice = images[0, 255:258, 255:258, -1]

        assert images.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.2951, 0.2955, 0.2922, 0.2036, 0.1977, 0.2279, 0.1716, 0.1641, 0.1799])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    @slow
    def test_stable_diffusion_onnx_intermediate_state(self):
        number_of_steps = 0

        def test_callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
            test_callback_fn.has_been_called = True
            nonlocal number_of_steps
            number_of_steps += 1
            if step == 0:
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array(
                    [-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592]
                )
                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
            elif step == 5:
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array(
                    [-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264]
                )
                assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3

        test_callback_fn.has_been_called = False

        pipe = OnnxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
        )
        pipe.set_progress_bar_config(disable=None)

        prompt = "Andromeda galaxy in a bottle"

        np.random.seed(0)
        pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
        assert test_callback_fn.has_been_called
        assert number_of_steps == 6

    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_stable_diffusion_accelerate_load_works(self):
        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
            return

        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
            return

        model_id = "CompVis/stable-diffusion-v1-4"
        _ = StableDiffusionPipeline.from_pretrained(
            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
        ).to(torch_device)

    @slow
    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
    def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
            return

        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
            return

        pipeline_id = "CompVis/stable-diffusion-v1-4"

        torch.cuda.empty_cache()
        gc.collect()

        tracemalloc.start()
        pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
        )
        pipeline_normal_load.to(torch_device)
        _, peak_normal = tracemalloc.get_traced_memory()
        tracemalloc.stop()

        del pipeline_normal_load
        torch.cuda.empty_cache()
        gc.collect()

        tracemalloc.start()
        _ = StableDiffusionPipeline.from_pretrained(
            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
        )
        _, peak_accelerate = tracemalloc.get_traced_memory()
        tracemalloc.stop()

        assert peak_accelerate < peak_normal
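
# A minimal sketch (added for illustration; it is not used by the tests above) of how
# the hard-coded `expected_slice` arrays in this file can be regenerated for an
# unconditional pipeline: run it once with a fixed seed and print the same 3x3 corner
# crop of the last channel that the assertions compare against.
def _print_expected_slice(pipe, num_inference_steps=2):
    # seed torch's global RNG so the run is deterministic, as the tests do
    generator = torch.manual_seed(0)
    image = pipe(generator=generator, num_inference_steps=num_inference_steps, output_type="numpy").images
    # same crop the tests use: the 3x3 bottom-right patch of the last channel
    print(np.round(image[0, -3:, -3:, -1].flatten(), 4))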