From 4f3ddb6ccafb353a1d7dece072462ab372bf7b75 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 7 Dec 2022 11:43:28 +0100 Subject: [PATCH] [Paint by Example] Better default for image width (#1587) --- .../pipeline_paint_by_example.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index 55842e87..83f10f82 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -442,14 +442,7 @@ class PaintByExamplePipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs - self.check_inputs(example_image, height, width, callback_steps) - - # 2. Define call parameters + # 1. Define call parameters if isinstance(image, PIL.Image.Image): batch_size = 1 elif isinstance(image, list): @@ -462,14 +455,18 @@ class PaintByExamplePipeline(DiffusionPipeline): # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # 3. Encode input image + # 2. Preprocess mask and image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + width, height = masked_image.shape[-2:] + + # 3. Check inputs + self.check_inputs(example_image, height, width, callback_steps) + + # 4. Encode input image image_embeddings = self._encode_image( example_image, device, num_images_per_prompt, do_classifier_free_guidance ) - # 4. Preprocess mask and image - mask, masked_image = prepare_mask_and_masked_image(image, mask_image) - # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps