[Img2Img] Fix batch size mismatch prompts vs. init images (#793)
* [Img2Img] Fix batch size mismatch prompts vs. init images * Remove bogus folder * fix * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
c1b6ea3dce
commit
6bc11782b7
|
@ -195,6 +195,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||||
"""
|
"""
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
|
prompt = [prompt]
|
||||||
elif isinstance(prompt, list):
|
elif isinstance(prompt, list):
|
||||||
batch_size = len(prompt)
|
batch_size = len(prompt)
|
||||||
else:
|
else:
|
||||||
|
@ -284,8 +285,23 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||||
init_latents = init_latent_dist.sample(generator=generator)
|
init_latents = init_latent_dist.sample(generator=generator)
|
||||||
init_latents = 0.18215 * init_latents
|
init_latents = 0.18215 * init_latents
|
||||||
|
|
||||||
# expand init_latents for batch_size
|
if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
|
||||||
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
|
# expand init_latents for batch_size
|
||||||
|
deprecation_message = (
|
||||||
|
f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
||||||
|
" images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
|
||||||
|
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
||||||
|
" your script to pass as many init images as text prompts to suppress this warning."
|
||||||
|
)
|
||||||
|
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||||
|
additional_image_per_prompt = len(prompt) // init_latents.shape[0]
|
||||||
|
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
|
||||||
|
elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||||
|
|
||||||
# get the original timestep using init_timestep
|
# get the original timestep using init_timestep
|
||||||
offset = self.scheduler.config.get("steps_offset", 0)
|
offset = self.scheduler.config.get("steps_offset", 0)
|
||||||
|
|
|
@ -698,6 +698,48 @@ class PipelineFastTests(unittest.TestCase):
|
||||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
|
def test_stable_diffusion_img2img_multiple_init_images(self):
|
||||||
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
|
unet = self.dummy_cond_unet
|
||||||
|
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||||
|
vae = self.dummy_vae
|
||||||
|
bert = self.dummy_text_encoder
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||||
|
|
||||||
|
init_image = self.dummy_image.to(device).repeat(2, 1, 1, 1)
|
||||||
|
|
||||||
|
# make sure here that pndm scheduler skips prk
|
||||||
|
sd_pipe = StableDiffusionImg2ImgPipeline(
|
||||||
|
unet=unet,
|
||||||
|
scheduler=scheduler,
|
||||||
|
vae=vae,
|
||||||
|
text_encoder=bert,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
safety_checker=self.dummy_safety_checker,
|
||||||
|
feature_extractor=self.dummy_extractor,
|
||||||
|
)
|
||||||
|
sd_pipe = sd_pipe.to(device)
|
||||||
|
sd_pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
|
prompt = 2 * ["A painting of a squirrel eating a burger"]
|
||||||
|
generator = torch.Generator(device=device).manual_seed(0)
|
||||||
|
output = sd_pipe(
|
||||||
|
prompt,
|
||||||
|
generator=generator,
|
||||||
|
guidance_scale=6.0,
|
||||||
|
num_inference_steps=2,
|
||||||
|
output_type="np",
|
||||||
|
init_image=init_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = output.images
|
||||||
|
|
||||||
|
image_slice = image[-1, -3:, -3:, -1]
|
||||||
|
|
||||||
|
assert image.shape == (2, 32, 32, 3)
|
||||||
|
expected_slice = np.array([0.5144, 0.4447, 0.4735, 0.6676, 0.5526, 0.5454, 0.645, 0.5149, 0.4689])
|
||||||
|
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||||
|
|
||||||
def test_stable_diffusion_img2img_k_lms(self):
|
def test_stable_diffusion_img2img_k_lms(self):
|
||||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
unet = self.dummy_cond_unet
|
unet = self.dummy_cond_unet
|
||||||
|
|
Loading…
Reference in New Issue