move test num_images_per_prompt to pipeline mixin (#2488)

* attend and excite batch test causing timeouts

* move test num_images_per_prompt to pipeline mixin

* style

* prompt_key -> self.batch_params
This commit is contained in:
Will Berman 2023-03-03 11:45:07 -08:00 committed by GitHub
parent 2f489571a7
commit a72a057d62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 55 additions and 329 deletions

View File

@ -517,8 +517,30 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
if (indices is None) or (indices is not None and not isinstance(indices, List)): indices_is_list_ints = isinstance(indices, list) and isinstance(indices[0], int)
raise ValueError(f"`indices` has to be a list but is {type(indices)}") indices_is_list_list_ints = (
isinstance(indices, list) and isinstance(indices[0], list) and isinstance(indices[0][0], int)
)
if not indices_is_list_ints and not indices_is_list_list_ints:
raise TypeError("`indices` must be a list of ints or a list of a list of ints")
if indices_is_list_ints:
indices_batch_size = 1
elif indices_is_list_list_ints:
indices_batch_size = len(indices)
if prompt is not None and isinstance(prompt, str):
prompt_batch_size = 1
elif prompt is not None and isinstance(prompt, list):
prompt_batch_size = len(prompt)
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]
if indices_batch_size != prompt_batch_size:
raise ValueError(
f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}"
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
@ -675,7 +697,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
token_indices: List[int], token_indices: Union[List[int], List[List[int]]],
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
num_inference_steps: int = 50, num_inference_steps: int = 50,
@ -851,7 +873,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
if isinstance(token_indices[0], int): if isinstance(token_indices[0], int):
token_indices = [token_indices] token_indices = [token_indices]
indices = [] indices = []
for ind in token_indices: for ind in token_indices:
indices = indices + [ind] * num_images_per_prompt indices = indices + [ind] * num_images_per_prompt

View File

@ -160,19 +160,6 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert out_1.shape == (1, 64, 64, 3) assert out_1.shape == (1, 64, 64, 3)
assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2 assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2
def test_paint_by_example_inpaint_with_num_images_per_prompt(self):
device = "cpu"
pipe = PaintByExamplePipeline(**self.get_dummy_components())
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs()
images = pipe(**inputs, num_images_per_prompt=2).images
# check if the output is a list of 2 images
assert len(images) == 2
@slow @slow
@require_torch_gpu @require_torch_gpu

View File

@ -41,7 +41,7 @@ class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"negative_prompt_embeds", "negative_prompt_embeds",
} }
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)

View File

@ -477,43 +477,6 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_num_images_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
# test num_images_per_prompt=1 (default)
images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
assert images.shape == (1, 64, 64, 3)
# test num_images_per_prompt=1 (default) for batch of prompts
batch_size = 2
images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images
assert images.shape == (batch_size, 64, 64, 3)
# test num_images_per_prompt for single prompt
num_images_per_prompt = 2
images = sd_pipe(
prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
).images
assert images.shape == (num_images_per_prompt, 64, 64, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
images = sd_pipe(
[prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
).images
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
def test_stable_diffusion_long_prompt(self): def test_stable_diffusion_long_prompt(self):
components = self.get_dummy_components() components = self.get_dummy_components()
components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config) components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config)

View File

@ -143,42 +143,6 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_img_variation_num_images_per_prompt(self):
device = "cpu"
components = self.get_dummy_components()
sd_pipe = StableDiffusionImageVariationPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
# test num_images_per_prompt=1 (default)
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs).images
assert images.shape == (1, 64, 64, 3)
# test num_images_per_prompt=1 (default) for batch of images
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["image"] = batch_size * [inputs["image"]]
images = sd_pipe(**inputs).images
assert images.shape == (batch_size, 64, 64, 3)
# test num_images_per_prompt for single prompt
num_images_per_prompt = 2
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (num_images_per_prompt, 64, 64, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["image"] = batch_size * [inputs["image"]]
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
@slow @slow
@require_torch_gpu @require_torch_gpu

View File

@ -181,42 +181,6 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_img2img_num_images_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
# test num_images_per_prompt=1 (default)
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs).images
assert images.shape == (1, 32, 32, 3)
# test num_images_per_prompt=1 (default) for batch of prompts
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = sd_pipe(**inputs).images
assert images.shape == (batch_size, 32, 32, 3)
# test num_images_per_prompt for single prompt
num_images_per_prompt = 2
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (num_images_per_prompt, 32, 32, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
@skip_mps @skip_mps
def test_save_load_local(self): def test_save_load_local(self):
return super().test_save_load_local() return super().test_save_load_local()

View File

@ -151,19 +151,6 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
assert out_pil.shape == (1, 64, 64, 3) assert out_pil.shape == (1, 64, 64, 3)
assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2 assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2
def test_stable_diffusion_inpaint_with_num_images_per_prompt(self):
device = "cpu"
components = self.get_dummy_components()
sd_pipe = StableDiffusionInpaintPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs, num_images_per_prompt=2).images
# check if the output is a list of 2 images
assert len(images) == 2
@slow @slow
@require_torch_gpu @require_torch_gpu

View File

@ -191,42 +191,6 @@ class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unitt
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_pix2pix_num_images_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionInstructPix2PixPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
# test num_images_per_prompt=1 (default)
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs).images
assert images.shape == (1, 32, 32, 3)
# test num_images_per_prompt=1 (default) for batch of prompts
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = sd_pipe(**inputs).images
assert images.shape == (batch_size, 32, 32, 3)
# test num_images_per_prompt for single prompt
num_images_per_prompt = 2
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (num_images_per_prompt, 32, 32, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
@slow @slow
@require_torch_gpu @require_torch_gpu

View File

@ -177,42 +177,6 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.Tes
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = sd_pipe(**inputs).images _ = sd_pipe(**inputs).images
def test_stable_diffusion_panorama_num_images_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPanoramaPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
# test num_images_per_prompt=1 (default)
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs).images
assert images.shape == (1, 64, 64, 3)
# test num_images_per_prompt=1 (default) for batch of prompts
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = sd_pipe(**inputs).images
assert images.shape == (batch_size, 64, 64, 3)
# test num_images_per_prompt for single prompt
num_images_per_prompt = 2
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (num_images_per_prompt, 64, 64, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
@slow @slow
@require_torch_gpu @require_torch_gpu

View File

@ -191,34 +191,6 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_pix2pix_zero_num_images_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
# test num_images_per_prompt=1 (default)
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs).images
assert images.shape == (1, 64, 64, 3)
# test num_images_per_prompt=2 for a single prompt
num_images_per_prompt = 2
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (num_images_per_prompt, 64, 64, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
# Non-determinism caused by the scheduler optimizing the latent inputs during inference # Non-determinism caused by the scheduler optimizing the latent inputs during inference
@unittest.skip("non-deterministic pipeline") @unittest.skip("non-deterministic pipeline")
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):

View File

@ -64,7 +64,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
test_save_load_optional_components = False test_save_load_optional_components = False
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"image"}
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
@ -340,42 +340,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_depth2img_num_images_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = StableDiffusionDepth2ImgPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
# test num_images_per_prompt=1 (default)
inputs = self.get_dummy_inputs(device)
images = pipe(**inputs).images
assert images.shape == (1, 32, 32, 3)
# test num_images_per_prompt=1 (default) for batch of prompts
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = pipe(**inputs).images
assert images.shape == (batch_size, 32, 32, 3)
# test num_images_per_prompt for single prompt
num_images_per_prompt = 2
inputs = self.get_dummy_inputs(device)
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (num_images_per_prompt, 32, 32, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
def test_stable_diffusion_depth2img_pil(self): def test_stable_diffusion_depth2img_pil(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components() components = self.get_dummy_components()

View File

@ -361,59 +361,6 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_unclip_image_variation_input_num_images_per_prompt(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
pipeline_inputs["image"] = [
pipeline_inputs["image"],
pipeline_inputs["image"],
]
output = pipe(**pipeline_inputs, num_images_per_prompt=2)
image = output.images
tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
tuple_pipeline_inputs["image"] = [
tuple_pipeline_inputs["image"],
tuple_pipeline_inputs["image"],
]
image_from_tuple = pipe(
**tuple_pipeline_inputs,
num_images_per_prompt=2,
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (4, 64, 64, 3)
expected_slice = np.array(
[
0.9980,
0.9997,
0.0023,
0.0029,
0.9997,
0.9985,
0.9997,
0.0010,
0.9995,
]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_unclip_passed_image_embed(self): def test_unclip_passed_image_embed(self):
device = torch.device("cpu") device = torch.device("cpu")

View File

@ -550,6 +550,32 @@ class PipelineTesterMixin:
_ = pipe(**inputs) _ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled") self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
def test_num_images_per_prompt(self):
sig = inspect.signature(self.pipeline_class.__call__)
if "num_images_per_prompt" not in sig.parameters:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
batch_sizes = [1, 2]
num_images_per_prompts = [1, 2]
for batch_size in batch_sizes:
for num_images_per_prompt in num_images_per_prompts:
inputs = self.get_dummy_inputs(torch_device)
for key in inputs.keys():
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape[0] == batch_size * num_images_per_prompt
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a