diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 72e15f4f..3e5ac4b3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -195,6 +195,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): """ if isinstance(prompt, str): batch_size = 1 + prompt = [prompt] elif isinstance(prompt, list): batch_size = len(prompt) else: @@ -284,8 +285,23 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): init_latents = init_latent_dist.sample(generator=generator) init_latents = 0.18215 * init_latents - # expand init_latents for batch_size - init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) + if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many init images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = len(prompt) // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." + ) + else: + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) # get the original timestep using init_timestep offset = self.scheduler.config.get("steps_offset", 0) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 30beb033..eecada22 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -698,6 +698,48 @@ class PipelineFastTests(unittest.TestCase): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_img2img_multiple_init_images(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + init_image = self.dummy_image.to(device).repeat(2, 1, 1, 1) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=self.dummy_safety_checker, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = 2 * ["A painting of a squirrel eating a burger"] + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + prompt, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ) + + image = output.images + + image_slice = image[-1, -3:, -3:, -1] + + assert image.shape == (2, 32, 32, 3) + expected_slice = np.array([0.5144, 0.4447, 0.4735, 0.6676, 0.5526, 0.5454, 0.645, 0.5149, 0.4689]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_img2img_k_lms(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet