VersatileDiffusion: fix input processing (#1568)

* fix versatile diffusion input

* merge main

* `make fix-copies`

Co-authored-by: anton- <anton@huggingface.co>
This commit is contained in:
Lukas Struppek 2022-12-12 13:45:27 +01:00 committed by GitHub
parent 31444f5790
commit 589330595d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 4 deletions

View File

@ -271,7 +271,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)
if height % 8 != 0 or width % 8 != 0:

View File

@ -240,7 +240,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)
if height % 8 != 0 or width % 8 != 0:

View File

@ -134,6 +134,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
return embeds
if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4:
prompt = [p for p in prompt]
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
@ -212,9 +215,17 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs
def check_inputs(self, image, height, width, callback_steps):
if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor):
raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
f" {type(image)}"
)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")