Fix for InstructPix2PixPipeline to allow for prompt embeds to be passed in without prompts. (#2456)

* fix check inputs to allow prompt embeds in instruct pix2pix

* linting

* add reference comment to check inputs

* remove comment

* style changes

---------

Co-authored-by: Will Berman <wlbberman@gmail.com>
This commit is contained in:
Dhruv Nair 2023-03-06 01:12:50 +05:30 committed by GitHub
parent 2ea1da89ab
commit e4c356d3f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 37 additions and 6 deletions

View File

@ -246,13 +246,19 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
# 0. Check inputs # 0. Check inputs
self.check_inputs(prompt, callback_steps) self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
if image is None: if image is None:
raise ValueError("`image` input cannot be undefined.") raise ValueError("`image` input cannot be undefined.")
# 1. Define call parameters # 1. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@ -640,10 +646,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image return image
def check_inputs(self, prompt, callback_steps): def check_inputs(
if not isinstance(prompt, str) and not isinstance(prompt, list): self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") ):
if (callback_steps is None) or ( if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
): ):
@ -652,6 +657,32 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)