add num_inference_steps arg to DDPM (#935)
This commit is contained in:
parent
82044153df
commit
6e099e2c8c
|
@ -42,6 +42,7 @@ class DDPMPipeline(DiffusionPipeline):
|
|||
self,
|
||||
batch_size: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
num_inference_steps: int = 1000,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
|
@ -53,6 +54,9 @@ class DDPMPipeline(DiffusionPipeline):
|
|||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
num_inference_steps (`int`, *optional*, defaults to 1000):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
|
@ -73,7 +77,7 @@ class DDPMPipeline(DiffusionPipeline):
|
|||
image = image.to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(1000)
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
# 1. predict noise model_output
|
||||
|
|
Loading…
Reference in New Issue