allow multiple generations per prompt (#741)
* compute text embeds per prompt
* don't repeat uncond prompts
* repeat separately
* update image2image
* fix repeat uncond embeds
* adapt inpaint pipeline
* fix uncond tokens in img2img
* add tests and fix uncond embeds in img2img and inpaint pipe
parent 367a671a06
commit c119dc4c04
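The new `num_images_per_prompt` argument lets each pipeline produce several samples per prompt in a single call. A minimal usage sketch (the model id and prompts are illustrative; this exact snippet is not part of the diff):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

prompts = ["a photo of an astronaut", "a watercolor fox"]
# Two prompts, three images each -> six images, ordered
# [p0, p0, p0, p1, p1, p1] because the embeddings are repeat_interleave'd.
images = pipe(prompts, num_images_per_prompt=3).images
assert len(images) == len(prompts) * 3
```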
@@ -117,6 +117,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -148,6 +149,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -215,6 +218,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
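`repeat_interleave` duplicates along the batch dimension so that all copies of the *same* prompt stay adjacent. A standalone shape sketch (illustrative dimensions, not pipeline code):

```python
# Sketch of the expansion above with dummy tensors.
import torch

batch_size, seq_len, dim = 2, 77, 768
num_images_per_prompt = 3
text_embeddings = torch.randn(batch_size, seq_len, dim)

# Copies of the same prompt stay adjacent: [p0, p0, p0, p1, p1, p1].
expanded = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
assert expanded.shape == (batch_size * num_images_per_prompt, seq_len, dim)
assert torch.equal(expanded[0], expanded[1])  # both are copies of prompt 0
```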
@@ -223,14 +229,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
+                uncond_tokens = [""]
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     " {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt] * batch_size
+                uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
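A side note on the context lines above (not something this commit changes): the `TypeError` message is a plain string, so `{type(negative_prompt)}` and `{type(prompt)}` print literally instead of being interpolated. A sketch of what it was presumably meant to be:

```python
# Hypothetical corrected message (not part of this commit): add the f prefix
# so the placeholders interpolate, and "as" rather than "to".
def check_types(prompt, negative_prompt):
    if type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
```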
@@ -250,6 +256,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
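Why `batch_size * num_images_per_prompt` here but only `num_images_per_prompt` for the text embeddings? After this change, `uncond_tokens` holds a single entry when `negative_prompt` is `None` or a string, so its one embedding row has to be tiled up to the full expanded batch. A shape sketch under that single-entry assumption (the list-of-negative-prompts branch is not visible in these hunks):

```python
# Shape check assuming uncond_tokens == [""] (one entry).
import torch

batch_size, num_images_per_prompt, seq_len, dim = 2, 3, 77, 768
text_embeddings = torch.randn(batch_size * num_images_per_prompt, seq_len, dim)
uncond_embeddings = torch.randn(1, seq_len, dim)  # one uncond prompt

uncond_embeddings = uncond_embeddings.repeat_interleave(
    batch_size * num_images_per_prompt, dim=0
)
# Both halves of the classifier-free-guidance batch now line up.
assert uncond_embeddings.shape == text_embeddings.shape
```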
@@ -260,7 +269,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
@@ -129,6 +129,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
@@ -164,6 +165,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -220,7 +223,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         init_latents = 0.18215 * init_latents

         # expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size)
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)

         # get the original timestep using init_timestep
         offset = self.scheduler.config.get("steps_offset", 0)
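One subtlety worth flagging: `torch.cat([x] * n, dim=0)` tiles the whole batch (`a, b, a, b`), whereas the text embeddings above use `repeat_interleave`, which keeps copies of the same item adjacent (`a, a, b, b`). For a single init image, the common img2img case, the two orderings coincide. A sketch of the difference (illustrative only):

```python
# Contrast torch.cat tiling with repeat_interleave ordering.
import torch

x = torch.tensor([[1.0], [2.0]])  # a batch of two latents: a, b

tiled = torch.cat([x] * 2, dim=0)            # [a, b, a, b]
interleaved = x.repeat_interleave(2, dim=0)  # [a, a, b, b]

assert tiled.squeeze().tolist() == [1.0, 2.0, 1.0, 2.0]
assert interleaved.squeeze().tolist() == [1.0, 1.0, 2.0, 2.0]
```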
@@ -228,7 +231,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         init_timestep = min(init_timestep, num_inference_steps)

         timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
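The starting timestep is a single scalar, but the scheduler's `add_noise` takes one timestep per sample, so the list multiplication has to grow by the same `batch_size * num_images_per_prompt` factor as the latents. A minimal sketch (the scalar value is an assumption):

```python
# Replicate one scalar start timestep per sample in the expanded batch.
import torch

batch_size, num_images_per_prompt = 2, 3
t = 981  # e.g. the value of self.scheduler.timesteps[-init_timestep]

timesteps = torch.tensor([t] * batch_size * num_images_per_prompt)
assert timesteps.shape == (batch_size * num_images_per_prompt,)
```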
@@ -252,6 +255,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -260,14 +266,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
+                uncond_tokens = [""]
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     " {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt] * batch_size
+                uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
             else:
@@ -283,6 +289,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
@@ -145,6 +145,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
@@ -184,6 +185,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
@@ -242,15 +245,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):

         init_latents = 0.18215 * init_latents

-        # Expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size)
+        # Expand init_latents for batch_size and num_images_per_prompt
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
         init_latents_orig = init_latents

         # preprocess mask
         if not isinstance(mask_image, torch.FloatTensor):
             mask_image = preprocess_mask(mask_image)
         mask_image = mask_image.to(self.device)
-        mask = torch.cat([mask_image] * batch_size)
+        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)

         # check sizes
         if not mask.shape == init_latents.shape:
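The mask is tiled by the same factor because of the `mask.shape == init_latents.shape` size check a few lines down; if only the latents grew, a previously valid call would now fail that check. A quick shape sketch (the latent dimensions are illustrative):

```python
# Mask and init_latents must share the full expanded batch shape.
import torch

batch_size, num_images_per_prompt = 2, 3
c, h, w = 4, 64, 64  # illustrative latent channels / spatial size

init_latents = torch.cat([torch.randn(1, c, h, w)] * batch_size * num_images_per_prompt, dim=0)
mask = torch.cat([torch.randn(1, c, h, w)] * batch_size * num_images_per_prompt)

assert mask.shape == init_latents.shape == (batch_size * num_images_per_prompt, c, h, w)
```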
@@ -262,7 +265,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         init_timestep = min(init_timestep, num_inference_steps)

         timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -286,6 +289,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -294,14 +300,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             uncond_tokens: List[str]
             if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
+                uncond_tokens = [""]
             elif type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     " {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt] * batch_size
+                uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -321,6 +327,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             )
             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

+            # duplicate unconditional embeddings for each generation per prompt
+            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
+
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
             # to avoid doing two forward passes
@@ -750,6 +750,202 @@ class PipelineFastTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_stable_diffusion_num_images_per_prompt(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        # test num_images_per_prompt=1 (default)
+        images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
+
+        assert images.shape == (1, 128, 128, 3)
+
+        # test num_images_per_prompt=1 (default) for batch of prompts
+        batch_size = 2
+        images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images
+
+        assert images.shape == (batch_size, 128, 128, 3)
+
+        # test num_images_per_prompt for single prompt
+        num_images_per_prompt = 2
+        images = sd_pipe(
+            prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
+        ).images
+
+        assert images.shape == (num_images_per_prompt, 128, 128, 3)
+
+        # test num_images_per_prompt for batch of prompts
+        batch_size = 2
+        images = sd_pipe(
+            [prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
+        ).images
+
+        assert images.shape == (batch_size * num_images_per_prompt, 128, 128, 3)
+
+    def test_stable_diffusion_img2img_num_images_per_prompt(self):
+        device = "cpu"
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        init_image = self.dummy_image.to(device)
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionImg2ImgPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        # test num_images_per_prompt=1 (default)
+        images = sd_pipe(
+            prompt,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+        ).images
+
+        assert images.shape == (1, 32, 32, 3)
+
+        # test num_images_per_prompt=1 (default) for batch of prompts
+        batch_size = 2
+        images = sd_pipe(
+            [prompt] * batch_size,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+        ).images
+
+        assert images.shape == (batch_size, 32, 32, 3)
+
+        # test num_images_per_prompt for single prompt
+        num_images_per_prompt = 2
+        images = sd_pipe(
+            prompt,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            num_images_per_prompt=num_images_per_prompt,
+        ).images
+
+        assert images.shape == (num_images_per_prompt, 32, 32, 3)
+
+        # test num_images_per_prompt for batch of prompts
+        batch_size = 2
+        images = sd_pipe(
+            [prompt] * batch_size,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            num_images_per_prompt=num_images_per_prompt,
+        ).images
+
+        assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
+
+    def test_stable_diffusion_inpaint_num_images_per_prompt(self):
+        device = "cpu"
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        # test num_images_per_prompt=1 (default)
+        images = sd_pipe(
+            prompt,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+        ).images
+
+        assert images.shape == (1, 32, 32, 3)
+
+        # test num_images_per_prompt=1 (default) for batch of prompts
+        batch_size = 2
+        images = sd_pipe(
+            [prompt] * batch_size,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+        ).images
+
+        assert images.shape == (batch_size, 32, 32, 3)
+
+        # test num_images_per_prompt for single prompt
+        num_images_per_prompt = 2
+        images = sd_pipe(
+            prompt,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+            num_images_per_prompt=num_images_per_prompt,
+        ).images
+
+        assert images.shape == (num_images_per_prompt, 32, 32, 3)
+
+        # test num_images_per_prompt for batch of prompts
+        batch_size = 2
+        images = sd_pipe(
+            [prompt] * batch_size,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+            num_images_per_prompt=num_images_per_prompt,
+        ).images
+
+        assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
+
+
 class PipelineTesterMixin(unittest.TestCase):
     def tearDown(self):
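Taken together, the three new tests pin down the same contract for every pipeline: the leading dimension of the returned image array equals `batch_size * num_images_per_prompt`, whether the prompt is a single string or a list (128×128 outputs for the text-to-image dummy UNet, 32×32 for the img2img and inpaint dummies).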