allow custom height, width in StableDiffusionPipeline (#179)
* allow custom height width * raise if height width are not mul of 8
This commit is contained in:
parent
c25d8c905c
commit
5f25818a0f
|
@ -28,6 +28,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt: Union[str, List[str]],
|
prompt: Union[str, List[str]],
|
||||||
|
height: Optional[int] = 512,
|
||||||
|
width: Optional[int] = 512,
|
||||||
num_inference_steps: Optional[int] = 50,
|
num_inference_steps: Optional[int] = 50,
|
||||||
guidance_scale: Optional[float] = 1.0,
|
guidance_scale: Optional[float] = 1.0,
|
||||||
eta: Optional[float] = 0.0,
|
eta: Optional[float] = 0.0,
|
||||||
|
@ -45,6 +47,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||||
|
|
||||||
|
if height % 8 != 0 or width % 8 != 0:
|
||||||
|
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||||
|
|
||||||
self.unet.to(torch_device)
|
self.unet.to(torch_device)
|
||||||
self.vae.to(torch_device)
|
self.vae.to(torch_device)
|
||||||
self.text_encoder.to(torch_device)
|
self.text_encoder.to(torch_device)
|
||||||
|
@ -72,7 +77,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||||
|
|
||||||
# get the intial random noise
|
# get the intial random noise
|
||||||
latents = torch.randn(
|
latents = torch.randn(
|
||||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
(batch_size, self.unet.in_channels, height // 8, width // 8),
|
||||||
generator=generator,
|
generator=generator,
|
||||||
)
|
)
|
||||||
latents = latents.to(torch_device)
|
latents = latents.to(torch_device)
|
||||||
|
|
Loading…
Reference in New Issue