From 5f25818a0fd8747c46b27becc9c63dcfbbfeb638 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 15 Aug 2022 10:28:03 +0530 Subject: [PATCH] allow custom height, width in StableDiffusionPipeline (#179) * allow custom height width * raise if height width are not mul of 8 --- .../stable_diffusion/pipeline_stable_diffusion.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 0f309625..9db646af 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -28,6 +28,8 @@ class StableDiffusionPipeline(DiffusionPipeline): def __call__( self, prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 1.0, eta: Optional[float] = 0.0, @@ -45,6 +47,9 @@ class StableDiffusionPipeline(DiffusionPipeline): else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + self.unet.to(torch_device) self.vae.to(torch_device) self.text_encoder.to(torch_device) @@ -72,7 +77,7 @@ class StableDiffusionPipeline(DiffusionPipeline): # get the intial random noise latents = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + (batch_size, self.unet.in_channels, height // 8, width // 8), generator=generator, ) latents = latents.to(torch_device)