move test num_images_per_prompt to pipeline mixin (#2488)
* attend and excite batch test causing timeouts
* move test num_images_per_prompt to pipeline mixin
* style
* prompt_key -> self.batch_params
This commit is contained in:
parent 2f489571a7
commit a72a057d62
@@ -517,8 +517,30 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
                     f" {negative_prompt_embeds.shape}."
                 )

-        if (indices is None) or (indices is not None and not isinstance(indices, List)):
-            raise ValueError(f"`indices` has to be a list but is {type(indices)}")
+        indices_is_list_ints = isinstance(indices, list) and isinstance(indices[0], int)
+        indices_is_list_list_ints = (
+            isinstance(indices, list) and isinstance(indices[0], list) and isinstance(indices[0][0], int)
+        )
+
+        if not indices_is_list_ints and not indices_is_list_list_ints:
+            raise TypeError("`indices` must be a list of ints or a list of a list of ints")
+
+        if indices_is_list_ints:
+            indices_batch_size = 1
+        elif indices_is_list_list_ints:
+            indices_batch_size = len(indices)
+
+        if prompt is not None and isinstance(prompt, str):
+            prompt_batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            prompt_batch_size = len(prompt)
+        elif prompt_embeds is not None:
+            prompt_batch_size = prompt_embeds.shape[0]
+
+        if indices_batch_size != prompt_batch_size:
+            raise ValueError(
+                f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}"
+            )

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
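Note: the replaced check only enforced that `indices` was a list; the new code distinguishes a flat list of ints (one prompt) from a list of lists (a batch of prompts) and requires the two batch sizes to agree. Below is a standalone sketch of that behavior for illustration; `check_indices` and the token positions are invented, the logic mirrors the diff above.

# Illustrative mirror of the validation added above; `check_indices` is a
# hypothetical helper, not part of the pipeline.
def check_indices(prompt, indices):
    indices_is_list_ints = isinstance(indices, list) and isinstance(indices[0], int)
    indices_is_list_list_ints = (
        isinstance(indices, list) and isinstance(indices[0], list) and isinstance(indices[0][0], int)
    )

    if not indices_is_list_ints and not indices_is_list_list_ints:
        raise TypeError("`indices` must be a list of ints or a list of a list of ints")

    # a flat list counts as batch size 1; a list of lists carries one list per prompt
    indices_batch_size = 1 if indices_is_list_ints else len(indices)
    prompt_batch_size = 1 if isinstance(prompt, str) else len(prompt)

    if indices_batch_size != prompt_batch_size:
        raise ValueError("indices batch size must be same as prompt batch size")


check_indices("a cat and a frog", [2, 5])          # ok: one prompt, flat list of ints
check_indices(["a cat", "a frog"], [[2], [2, 5]])  # ok: two prompts, one list each
# check_indices(["a cat", "a frog"], [2, 5])       # would raise: batch sizes 1 != 2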
@@ -675,7 +697,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        token_indices: List[int],
+        token_indices: Union[List[int], List[List[int]]],
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -851,7 +873,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):

         if isinstance(token_indices[0], int):
             token_indices = [token_indices]

         indices = []

         for ind in token_indices:
             indices = indices + [ind] * num_images_per_prompt
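Note: once validated, a flat `token_indices` is wrapped into a batch of one, and each prompt's index list is then repeated once per requested image, which is exactly what the loop in the hunk above does. A pure-Python illustration with assumed values:

# How the loop above expands indices when num_images_per_prompt > 1:
token_indices = [[2, 5], [3]]  # assumed: indices for a batch of two prompts
num_images_per_prompt = 2

indices = []
for ind in token_indices:
    indices = indices + [ind] * num_images_per_prompt

assert indices == [[2, 5], [2, 5], [3], [3]]  # one entry per generated image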
@@ -160,19 +160,6 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert out_1.shape == (1, 64, 64, 3)
         assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2

-    def test_paint_by_example_inpaint_with_num_images_per_prompt(self):
-        device = "cpu"
-        pipe = PaintByExamplePipeline(**self.get_dummy_components())
-        pipe = pipe.to(device)
-        pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs()
-
-        images = pipe(**inputs, num_images_per_prompt=2).images
-
-        # check if the output is a list of 2 images
-        assert len(images) == 2
-
     @slow
     @require_torch_gpu
@@ -41,7 +41,7 @@ class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         "negative_prompt_embeds",
     }
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
-    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})

     def get_dummy_components(self):
         torch.manual_seed(0)
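Note: `batch_params` declares which `__call__` arguments the shared mixin tests replicate when they build batched inputs (see `test_num_images_per_prompt` at the end of this diff). CycleDiffusion conditions on a `source_prompt` alongside the target `prompt`, so it is added to the set here. A rough sketch of the replication step, with an assumed inputs dict:

# Assumed dummy inputs for a CycleDiffusion-style pipeline (illustrative only):
inputs = {"prompt": "black cat", "source_prompt": "orange cat", "strength": 0.75}
batch_params = {"prompt", "source_prompt"}  # simplified stand-in for the real set
batch_size = 2

for key in inputs.keys():
    if key in batch_params:
        inputs[key] = batch_size * [inputs[key]]

assert inputs["prompt"] == ["black cat", "black cat"]
assert inputs["source_prompt"] == ["orange cat", "orange cat"]
assert inputs["strength"] == 0.75  # non-batch params stay scalar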
@@ -477,43 +477,6 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

-    def test_stable_diffusion_num_images_per_prompt(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components()
-        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
-        sd_pipe = StableDiffusionPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        # test num_images_per_prompt=1 (default)
-        images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
-
-        assert images.shape == (1, 64, 64, 3)
-
-        # test num_images_per_prompt=1 (default) for batch of prompts
-        batch_size = 2
-        images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images
-
-        assert images.shape == (batch_size, 64, 64, 3)
-
-        # test num_images_per_prompt for single prompt
-        num_images_per_prompt = 2
-        images = sd_pipe(
-            prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
-        ).images
-
-        assert images.shape == (num_images_per_prompt, 64, 64, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        images = sd_pipe(
-            [prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
-        ).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
-
     def test_stable_diffusion_long_prompt(self):
         components = self.get_dummy_components()
         components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config)
@@ -143,42 +143,6 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-    def test_stable_diffusion_img_variation_num_images_per_prompt(self):
-        device = "cpu"
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionImageVariationPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # test num_images_per_prompt=1 (default)
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (1, 64, 64, 3)
-
-        # test num_images_per_prompt=1 (default) for batch of images
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["image"] = batch_size * [inputs["image"]]
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (batch_size, 64, 64, 3)
-
-        # test num_images_per_prompt for single prompt
-        num_images_per_prompt = 2
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (num_images_per_prompt, 64, 64, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["image"] = batch_size * [inputs["image"]]
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
-
     @slow
     @require_torch_gpu
@@ -181,42 +181,6 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-    def test_stable_diffusion_img2img_num_images_per_prompt(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # test num_images_per_prompt=1 (default)
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (1, 32, 32, 3)
-
-        # test num_images_per_prompt=1 (default) for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (batch_size, 32, 32, 3)
-
-        # test num_images_per_prompt for single prompt
-        num_images_per_prompt = 2
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (num_images_per_prompt, 32, 32, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
-
     @skip_mps
     def test_save_load_local(self):
         return super().test_save_load_local()
@@ -151,19 +151,6 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert out_pil.shape == (1, 64, 64, 3)
         assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2

-    def test_stable_diffusion_inpaint_with_num_images_per_prompt(self):
-        device = "cpu"
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionInpaintPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=2).images
-
-        # check if the output is a list of 2 images
-        assert len(images) == 2
-
     @slow
     @require_torch_gpu
@@ -191,42 +191,6 @@ class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-    def test_stable_diffusion_pix2pix_num_images_per_prompt(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionInstructPix2PixPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # test num_images_per_prompt=1 (default)
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (1, 32, 32, 3)
-
-        # test num_images_per_prompt=1 (default) for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (batch_size, 32, 32, 3)
-
-        # test num_images_per_prompt for single prompt
-        num_images_per_prompt = 2
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (num_images_per_prompt, 32, 32, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
-
     @slow
     @require_torch_gpu
@@ -177,42 +177,6 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         with self.assertRaises(ValueError):
             _ = sd_pipe(**inputs).images

-    def test_stable_diffusion_panorama_num_images_per_prompt(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionPanoramaPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # test num_images_per_prompt=1 (default)
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (1, 64, 64, 3)
-
-        # test num_images_per_prompt=1 (default) for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (batch_size, 64, 64, 3)
-
-        # test num_images_per_prompt for single prompt
-        num_images_per_prompt = 2
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (num_images_per_prompt, 64, 64, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
-
     @slow
     @require_torch_gpu
@@ -191,34 +191,6 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-    def test_stable_diffusion_pix2pix_zero_num_images_per_prompt(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        # test num_images_per_prompt=1 (default)
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs).images
-
-        assert images.shape == (1, 64, 64, 3)
-
-        # test num_images_per_prompt=2 for a single prompt
-        num_images_per_prompt = 2
-        inputs = self.get_dummy_inputs(device)
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (num_images_per_prompt, 64, 64, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
-
     # Non-determinism caused by the scheduler optimizing the latent inputs during inference
     @unittest.skip("non-deterministic pipeline")
     def test_inference_batch_single_identical(self):
@@ -64,7 +64,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     test_save_load_optional_components = False
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
-    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"image"}

     def get_dummy_components(self):
         torch.manual_seed(0)
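Note: here the set is narrowed instead of widened: subtracting `"image"` keeps the mixin from replicating the image input for depth2img, so only the remaining batch params are listified in the shared test. The diff does not say why; presumably the depth-conditioned dummy image does not support naive list batching. A short illustration of the set arithmetic, with assumed set contents:

# Assumed contents of the constant, for illustration only:
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset({"prompt", "negative_prompt", "image"})
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"image"}
assert batch_params == frozenset({"prompt", "negative_prompt"})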
@@ -340,42 +340,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-    def test_stable_diffusion_depth2img_num_images_per_prompt(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components()
-        pipe = StableDiffusionDepth2ImgPipeline(**components)
-        pipe = pipe.to(device)
-        pipe.set_progress_bar_config(disable=None)
-
-        # test num_images_per_prompt=1 (default)
-        inputs = self.get_dummy_inputs(device)
-        images = pipe(**inputs).images
-
-        assert images.shape == (1, 32, 32, 3)
-
-        # test num_images_per_prompt=1 (default) for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = pipe(**inputs).images
-
-        assert images.shape == (batch_size, 32, 32, 3)
-
-        # test num_images_per_prompt for single prompt
-        num_images_per_prompt = 2
-        inputs = self.get_dummy_inputs(device)
-        images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (num_images_per_prompt, 32, 32, 3)
-
-        # test num_images_per_prompt for batch of prompts
-        batch_size = 2
-        inputs = self.get_dummy_inputs(device)
-        inputs["prompt"] = [inputs["prompt"]] * batch_size
-        images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
-
-        assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
-
     def test_stable_diffusion_depth2img_pil(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -361,59 +361,6 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

-    def test_unclip_image_variation_input_num_images_per_prompt(self):
-        device = "cpu"
-
-        components = self.get_dummy_components()
-
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(device)
-
-        pipe.set_progress_bar_config(disable=None)
-
-        pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
-        pipeline_inputs["image"] = [
-            pipeline_inputs["image"],
-            pipeline_inputs["image"],
-        ]
-
-        output = pipe(**pipeline_inputs, num_images_per_prompt=2)
-        image = output.images
-
-        tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
-        tuple_pipeline_inputs["image"] = [
-            tuple_pipeline_inputs["image"],
-            tuple_pipeline_inputs["image"],
-        ]
-
-        image_from_tuple = pipe(
-            **tuple_pipeline_inputs,
-            num_images_per_prompt=2,
-            return_dict=False,
-        )[0]
-
-        image_slice = image[0, -3:, -3:, -1]
-        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
-        assert image.shape == (4, 64, 64, 3)
-
-        expected_slice = np.array(
-            [
-                0.9980,
-                0.9997,
-                0.0023,
-                0.0029,
-                0.9997,
-                0.9985,
-                0.9997,
-                0.0010,
-                0.9995,
-            ]
-        )
-
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
     def test_unclip_passed_image_embed(self):
         device = torch.device("cpu")
@@ -550,6 +550,32 @@ class PipelineTesterMixin:
                 _ = pipe(**inputs)
             self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")

+    def test_num_images_per_prompt(self):
+        sig = inspect.signature(self.pipeline_class.__call__)
+
+        if "num_images_per_prompt" not in sig.parameters:
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        batch_sizes = [1, 2]
+        num_images_per_prompts = [1, 2]
+
+        for batch_size in batch_sizes:
+            for num_images_per_prompt in num_images_per_prompts:
+                inputs = self.get_dummy_inputs(torch_device)
+
+                for key in inputs.keys():
+                    if key in self.batch_params:
+                        inputs[key] = batch_size * [inputs[key]]
+
+                images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
+
+                assert images.shape[0] == batch_size * num_images_per_prompt
+
     # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
     # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
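Note: the new mixin test skips itself for pipelines whose `__call__` does not accept `num_images_per_prompt`, and otherwise relies only on the standard tester contract (`pipeline_class`, `batch_params`, `get_dummy_components`, `get_dummy_inputs`). A minimal, self-contained sketch of that contract; everything named `Fake*` below is hypothetical, only the attribute and hook names come from the diff.

import unittest
from types import SimpleNamespace

import numpy as np


class FakePipeline:
    """Hypothetical stand-in covering only what test_num_images_per_prompt touches."""

    def to(self, device):
        return self  # no weights to move in this sketch

    def set_progress_bar_config(self, disable=None):
        pass

    def __call__(self, prompt, num_images_per_prompt=1):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        # one dummy 64x64 RGB image per prompt per requested copy
        images = np.zeros((batch_size * num_images_per_prompt, 64, 64, 3))
        return SimpleNamespace(images=images)


class FakePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = FakePipeline
    batch_params = frozenset({"prompt"})  # inputs replicated per batch element

    def get_dummy_components(self):
        return {}  # FakePipeline takes no components

    def get_dummy_inputs(self, device, seed=0):
        return {"prompt": "a photo of a cat"}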