VersatileDiffusion: fix input processing (#1568)
* fix versatile diffusion input * merge main * `make fix-copies` Co-authored-by: anton- <anton@huggingface.co>
This commit is contained in:
parent
31444f5790
commit
589330595d
|
@ -271,7 +271,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
|
|||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
|
||||
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
||||
f" {type(image)}"
|
||||
)
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
|
|
|
@ -240,7 +240,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
|||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
|
||||
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
||||
f" {type(image)}"
|
||||
)
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
|
|
|
@ -134,6 +134,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
|||
embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
|
||||
return embeds
|
||||
|
||||
if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4:
|
||||
prompt = [p for p in prompt]
|
||||
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
# get prompt text embeddings
|
||||
|
@ -212,9 +215,17 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
|||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs
|
||||
def check_inputs(self, image, height, width, callback_steps):
|
||||
if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor):
|
||||
raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
|
||||
if (
|
||||
not isinstance(image, torch.Tensor)
|
||||
and not isinstance(image, PIL.Image.Image)
|
||||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
||||
f" {type(image)}"
|
||||
)
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
|
Loading…
Reference in New Issue