Don't assume 512x512 in k-diffusion pipeline (#1625)
Don't assume 512x512 in k-diffusion pipeline.
This commit is contained in:
parent
f1b726e46e
commit
ff65c2d72b
|
@ -325,8 +325,8 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt: Union[str, List[str]],
|
prompt: Union[str, List[str]],
|
||||||
height: int = 512,
|
height: Optional[int] = None,
|
||||||
width: int = 512,
|
width: Optional[int] = None,
|
||||||
num_inference_steps: int = 50,
|
num_inference_steps: int = 50,
|
||||||
guidance_scale: float = 7.5,
|
guidance_scale: float = 7.5,
|
||||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||||
|
@ -345,9 +345,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
||||||
Args:
|
Args:
|
||||||
prompt (`str` or `List[str]`):
|
prompt (`str` or `List[str]`):
|
||||||
The prompt or prompts to guide the image generation.
|
The prompt or prompts to guide the image generation.
|
||||||
height (`int`, *optional*, defaults to 512):
|
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||||
The height in pixels of the generated image.
|
The height in pixels of the generated image.
|
||||||
width (`int`, *optional*, defaults to 512):
|
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||||
The width in pixels of the generated image.
|
The width in pixels of the generated image.
|
||||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
|
@ -393,6 +393,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
||||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||||
(nsfw) content, according to the `safety_checker`.
|
(nsfw) content, according to the `safety_checker`.
|
||||||
"""
|
"""
|
||||||
|
# 0. Default height and width to unet
|
||||||
|
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
|
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
|
|
||||||
# 1. Check inputs. Raise error if not correct
|
# 1. Check inputs. Raise error if not correct
|
||||||
self.check_inputs(prompt, height, width, callback_steps)
|
self.check_inputs(prompt, height, width, callback_steps)
|
||||||
|
|
Loading…
Reference in New Issue