From c119dc4c045506b53b2eb893043d1f774fa3e68e Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Thu, 6 Oct 2022 14:01:45 +0200
Subject: [PATCH] allow multiple generations per prompt (#741)

* compute text embeds per prompt

* don't repeat uncond prompts

* repeat separately

* update image2image

* fix repeat uncond embeds

* adapt inpaint pipeline

* fix uncond tokens in img2img

* add tests and fix uncond embeds in im2img and inpaint pipe
---
 .../pipeline_stable_diffusion.py         |  15 +-
 .../pipeline_stable_diffusion_img2img.py |  17 +-
 .../pipeline_stable_diffusion_inpaint.py |  21 +-
 tests/test_pipelines.py                  | 196 ++++++++++++++++++
 4 files changed, 236 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index cddba1cd..00e72de6 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -117,6 +117,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -148,6 +149,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -215,6 +218,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -223,14 +229,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
+                uncond_tokens = [""]
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     " {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt] * batch_size
+                uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -250,6 +256,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
@@ -260,7 +269,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 4d706f25..15bdd020 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -129,6 +129,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
@@ -164,6 +165,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -220,7 +223,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         init_latents = 0.18215 * init_latents

         # expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size)
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)

         # get the original timestep using init_timestep
         offset = self.scheduler.config.get("steps_offset", 0)
@@ -228,7 +231,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         init_timestep = min(init_timestep, num_inference_steps)

         timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -252,6 +255,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -260,14 +266,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
+                uncond_tokens = [""]
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     " {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt] * batch_size
+                uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
             else:
@@ -283,6 +289,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 24dba21a..24f4bc99 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -145,6 +145,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
@@ -184,6 +185,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -242,15 +245,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):

         init_latents = 0.18215 * init_latents

-        # Expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size)
+        # Expand init_latents for batch_size and num_images_per_prompt
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
         init_latents_orig = init_latents

         # preprocess mask
         if not isinstance(mask_image, torch.FloatTensor):
             mask_image = preprocess_mask(mask_image)
         mask_image = mask_image.to(self.device)
-        mask = torch.cat([mask_image] * batch_size)
+        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)

         # check sizes
         if not mask.shape == init_latents.shape:
@@ -262,7 +265,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         init_timestep = min(init_timestep, num_inference_steps)

         timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -286,6 +289,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -294,14 +300,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
+                uncond_tokens = [""]
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     " {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt] * batch_size
+                uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -321,6 +327,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 83097962..e6cc37ad 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -750,6 +750,202 @@ class PipelineFastTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_num_images_per_prompt(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        # test num_images_per_prompt=1 (default)
+        images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
+
+        assert images.shape == (1, 128, 128, 3)
+
+        # test num_images_per_prompt=1 (default) for batch of prompts
+        batch_size = 2
+        images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images
+
+        assert images.shape == (batch_size, 128, 128, 3)
+
+        # test num_images_per_prompt for single prompt
+        num_images_per_prompt = 2
+        images = sd_pipe(
+            prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
+        ).images
+
+        assert images.shape == (num_images_per_prompt, 128, 128, 3)
+
+        # test num_images_per_prompt for batch of prompts
+        batch_size = 2
+        images = sd_pipe(
+            [prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
+        ).images
+
+        assert images.shape == (batch_size * num_images_per_prompt, 128, 128, 3)
+
+    def test_stable_diffusion_img2img_num_images_per_prompt(self):
+        device = "cpu"
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        init_image = self.dummy_image.to(device)
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionImg2ImgPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        # test num_images_per_prompt=1 (default)
+        images = sd_pipe(
+            prompt,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+        ).images
+
+        assert images.shape == (1, 32, 32, 3)
+
+        # test num_images_per_prompt=1 (default) for batch of prompts
+        batch_size = 2
+        images = sd_pipe(
+            [prompt] * batch_size,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+        ).images
+
+        assert images.shape == (batch_size, 32, 32, 3)
+
+        # test num_images_per_prompt for single prompt
+        num_images_per_prompt = 2
+        images = sd_pipe(
+            prompt,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            num_images_per_prompt=num_images_per_prompt,
+        ).images
+
+        assert images.shape == (num_images_per_prompt, 32, 32, 3)
+
+        # test num_images_per_prompt for batch of prompts
+        batch_size = 2
+        images = sd_pipe(
+            [prompt] * batch_size,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            num_images_per_prompt=num_images_per_prompt,
+        ).images
+
+        assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
+
+    def test_stable_diffusion_inpaint_num_images_per_prompt(self):
+        device = "cpu"
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        # test num_images_per_prompt=1 (default)
+        images = sd_pipe(
+            prompt,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+        ).images
+
+        assert images.shape == (1, 32, 32, 3)
+
+        # test num_images_per_prompt=1 (default) for batch of prompts
+        batch_size = 2
+        images = sd_pipe(
+            [prompt] * batch_size,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+        ).images
+
+        assert images.shape == (batch_size, 32, 32, 3)
+
+        # test num_images_per_prompt for single prompt
+        num_images_per_prompt = 2
+        images = sd_pipe(
+            prompt,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+            num_images_per_prompt=num_images_per_prompt,
+        ).images
+
+        assert images.shape == (num_images_per_prompt, 32, 32, 3)
+
+        # test num_images_per_prompt for batch of prompts
+        batch_size = 2
+        images = sd_pipe(
+            [prompt] * batch_size,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+            num_images_per_prompt=num_images_per_prompt,
+        ).images
+
+        assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
+

 class PipelineTesterMixin(unittest.TestCase):
     def tearDown(self):
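
As a rough usage sketch of the new argument (the checkpoint id below is only illustrative and not part of this patch), each prompt now yields `num_images_per_prompt` images; because the text embeddings are expanded with `repeat_interleave`, all samples for a given prompt sit next to each other in the output batch:

    from diffusers import StableDiffusionPipeline

    # illustrative checkpoint; any checkpoint loadable by StableDiffusionPipeline
    # is expected to behave the same way with this patch applied
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

    prompts = ["a photo of an astronaut riding a horse", "a watercolor painting of a fox"]
    num_images_per_prompt = 3

    out = pipe(prompts, num_images_per_prompt=num_images_per_prompt)

    # batch_size * num_images_per_prompt images in total, ordered as
    # [prompt0 sample0..2, prompt1 sample0..2] due to repeat_interleave
    assert len(out.images) == len(prompts) * num_images_per_prompt

The same keyword is accepted by StableDiffusionImg2ImgPipeline and StableDiffusionInpaintPipeline, where the init latents, mask, and timesteps are likewise tiled to batch_size * num_images_per_prompt.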