allow multiple generations per prompt (#741)

* compute text embeds per prompt

* don't repeat uncond prompts

* repeat separatly

* update image2image

* fix repeat uncond embeds

* adapt inpaint pipeline

* ifx uncond tokens in img2img

* add tests and fix ucond embeds in im2img and inpaint pipe
This commit is contained in:
Suraj Patil 2022-10-06 14:01:45 +02:00 committed by GitHub
parent 367a671a06
commit c119dc4c04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 236 additions and 13 deletions

View File

@ -117,6 +117,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
@ -148,6 +149,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@ -215,6 +218,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@ -223,14 +229,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] * batch_size
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@ -250,6 +256,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
@ -260,7 +269,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":

View File

@ -129,6 +129,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
@ -164,6 +165,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@ -220,7 +223,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_latents = 0.18215 * init_latents
# expand init_latents for batch_size
init_latents = torch.cat([init_latents] * batch_size)
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
@ -228,7 +231,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@ -252,6 +255,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@ -260,14 +266,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] * batch_size
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
else:
@ -283,6 +289,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes

View File

@ -145,6 +145,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
@ -184,6 +185,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
@ -242,15 +245,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
init_latents = 0.18215 * init_latents
# Expand init_latents for batch_size
init_latents = torch.cat([init_latents] * batch_size)
# Expand init_latents for batch_size and num_images_per_prompt
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
init_latents_orig = init_latents
# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(self.device)
mask = torch.cat([mask_image] * batch_size)
mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
# check sizes
if not mask.shape == init_latents.shape:
@ -262,7 +265,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@ -286,6 +289,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@ -294,14 +300,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
uncond_tokens = [""]
elif type(prompt) is not type(negative_prompt):
raise TypeError(
"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] * batch_size
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@ -321,6 +327,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# duplicate unconditional embeddings for each generation per prompt
uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes

View File

@ -750,6 +750,202 @@ class PipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_num_images_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
# test num_images_per_prompt=1 (default)
images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
assert images.shape == (1, 128, 128, 3)
# test num_images_per_prompt=1 (default) for batch of prompts
batch_size = 2
images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images
assert images.shape == (batch_size, 128, 128, 3)
# test num_images_per_prompt for single prompt
num_images_per_prompt = 2
images = sd_pipe(
prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
).images
assert images.shape == (num_images_per_prompt, 128, 128, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
images = sd_pipe(
[prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
).images
assert images.shape == (batch_size * num_images_per_prompt, 128, 128, 3)
def test_stable_diffusion_img2img_num_images_per_prompt(self):
device = "cpu"
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
init_image = self.dummy_image.to(device)
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionImg2ImgPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
# test num_images_per_prompt=1 (default)
images = sd_pipe(
prompt,
num_inference_steps=2,
output_type="np",
init_image=init_image,
).images
assert images.shape == (1, 32, 32, 3)
# test num_images_per_prompt=1 (default) for batch of prompts
batch_size = 2
images = sd_pipe(
[prompt] * batch_size,
num_inference_steps=2,
output_type="np",
init_image=init_image,
).images
assert images.shape == (batch_size, 32, 32, 3)
# test num_images_per_prompt for single prompt
num_images_per_prompt = 2
images = sd_pipe(
prompt,
num_inference_steps=2,
output_type="np",
init_image=init_image,
num_images_per_prompt=num_images_per_prompt,
).images
assert images.shape == (num_images_per_prompt, 32, 32, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
images = sd_pipe(
[prompt] * batch_size,
num_inference_steps=2,
output_type="np",
init_image=init_image,
num_images_per_prompt=num_images_per_prompt,
).images
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
def test_stable_diffusion_inpaint_num_images_per_prompt(self):
device = "cpu"
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionInpaintPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=self.dummy_safety_checker,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
# test num_images_per_prompt=1 (default)
images = sd_pipe(
prompt,
num_inference_steps=2,
output_type="np",
init_image=init_image,
mask_image=mask_image,
).images
assert images.shape == (1, 32, 32, 3)
# test num_images_per_prompt=1 (default) for batch of prompts
batch_size = 2
images = sd_pipe(
[prompt] * batch_size,
num_inference_steps=2,
output_type="np",
init_image=init_image,
mask_image=mask_image,
).images
assert images.shape == (batch_size, 32, 32, 3)
# test num_images_per_prompt for single prompt
num_images_per_prompt = 2
images = sd_pipe(
prompt,
num_inference_steps=2,
output_type="np",
init_image=init_image,
mask_image=mask_image,
num_images_per_prompt=num_images_per_prompt,
).images
assert images.shape == (num_images_per_prompt, 32, 32, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
images = sd_pipe(
[prompt] * batch_size,
num_inference_steps=2,
output_type="np",
init_image=init_image,
mask_image=mask_image,
num_images_per_prompt=num_images_per_prompt,
).images
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
class PipelineTesterMixin(unittest.TestCase):
def tearDown(self):