diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 18d024a7..cae3c4fe 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -271,7 +271,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
             and not isinstance(image, list)
         ):
             raise ValueError(
-                f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
+                "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+                f" {type(image)}"
             )

         if height % 8 != 0 or width % 8 != 0:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index 1d34280d..71222f4a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -240,7 +240,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
             and not isinstance(image, list)
         ):
             raise ValueError(
-                f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
+                "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+                f" {type(image)}"
             )

         if height % 8 != 0 or width % 8 != 0:
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
index 87924fdf..7419d2f3 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -134,6 +134,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
             embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
             return embeds

+        if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4:
+            prompt = [p for p in prompt]
+
         batch_size = len(prompt) if isinstance(prompt, list) else 1

         # get prompt text embeddings
@@ -212,9 +215,17 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs
     def check_inputs(self, image, height, width, callback_steps):
-        if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor):
-            raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
+        if (
+            not isinstance(image, torch.Tensor)
+            and not isinstance(image, PIL.Image.Image)
+            and not isinstance(image, list)
+        ):
+            raise ValueError(
+                "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+                f" {type(image)}"
+            )

         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
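
For context, a minimal usage sketch of what the relaxed `check_inputs` permits after this patch: passing a `List[PIL.Image.Image]` as `image` to an image-variation pipeline. The checkpoint name and file paths below are placeholders, not part of the diff.

```python
# Sketch only; the checkpoint and image paths are assumed placeholders.
import PIL.Image
import torch
from diffusers import StableDiffusionImageVariationPipeline

pipe = StableDiffusionImageVariationPipeline.from_pretrained(
    "lambdalabs/sd-image-variations-diffusers", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# With the widened type check, `image` may be a list of PIL images
# rather than only a single image or a pre-batched tensor.
images = [
    PIL.Image.open("example_1.png").convert("RGB"),
    PIL.Image.open("example_2.png").convert("RGB"),
]
variations = pipe(image=images, guidance_scale=3.0, num_inference_steps=25).images
for i, img in enumerate(variations):
    img.save(f"variation_{i}.png")
```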