[Paint by Example] Better default for image width (#1587)
This commit is contained in:
parent
4eb9ad0d1c
commit
4f3ddb6cca
|
@ -442,14 +442,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
|||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(example_image, height, width, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
# 1. Define call parameters
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
batch_size = 1
|
||||
elif isinstance(image, list):
|
||||
|
@ -462,14 +455,18 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
|||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input image
|
||||
# 2. Preprocess mask and image
|
||||
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
|
||||
width, height = masked_image.shape[-2:]
|
||||
|
||||
# 3. Check inputs
|
||||
self.check_inputs(example_image, height, width, callback_steps)
|
||||
|
||||
# 4. Encode input image
|
||||
image_embeddings = self._encode_image(
|
||||
example_image, device, num_images_per_prompt, do_classifier_free_guidance
|
||||
)
|
||||
|
||||
# 4. Preprocess mask and image
|
||||
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
|
Loading…
Reference in New Issue