Fix for InstructPix2PixPipeline to allow for prompt embeds to be passed in without prompts. (#2456)
* fix check inputs to allow prompt embeds in instruct pix2pix * linting * add reference comment to check inputs * remove comment * style changes --------- Co-authored-by: Will Berman <wlbberman@gmail.com>
This commit is contained in:
parent
2ea1da89ab
commit
e4c356d3f6
|
@ -246,13 +246,19 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
|
||||||
(nsfw) content, according to the `safety_checker`.
|
(nsfw) content, according to the `safety_checker`.
|
||||||
"""
|
"""
|
||||||
# 0. Check inputs
|
# 0. Check inputs
|
||||||
self.check_inputs(prompt, callback_steps)
|
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
||||||
|
|
||||||
if image is None:
|
if image is None:
|
||||||
raise ValueError("`image` input cannot be undefined.")
|
raise ValueError("`image` input cannot be undefined.")
|
||||||
|
|
||||||
# 1. Define call parameters
|
# 1. Define call parameters
|
||||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
if prompt is not None and isinstance(prompt, str):
|
||||||
|
batch_size = 1
|
||||||
|
elif prompt is not None and isinstance(prompt, list):
|
||||||
|
batch_size = len(prompt)
|
||||||
|
else:
|
||||||
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
device = self._execution_device
|
device = self._execution_device
|
||||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||||
|
@ -640,10 +646,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
|
||||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def check_inputs(self, prompt, callback_steps):
|
def check_inputs(
|
||||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
|
||||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
):
|
||||||
|
|
||||||
if (callback_steps is None) or (
|
if (callback_steps is None) or (
|
||||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||||
):
|
):
|
||||||
|
@ -652,6 +657,32 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
|
||||||
f" {type(callback_steps)}."
|
f" {type(callback_steps)}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if prompt is not None and prompt_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||||
|
" only forward one of the two."
|
||||||
|
)
|
||||||
|
elif prompt is None and prompt_embeds is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||||
|
)
|
||||||
|
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||||
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||||
|
|
||||||
|
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||||
|
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||||
|
)
|
||||||
|
|
||||||
|
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||||
|
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||||
|
raise ValueError(
|
||||||
|
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||||
|
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||||
|
f" {negative_prompt_embeds.shape}."
|
||||||
|
)
|
||||||
|
|
||||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||||
|
|
Loading…
Reference in New Issue