add num_inference_steps arg to DDPM (#935)

This commit is contained in:
Tanishq Abraham 2022-10-25 04:08:56 -07:00 committed by GitHub
parent 82044153df
commit 6e099e2c8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 1 deletions

View File

@ -42,6 +42,7 @@ class DDPMPipeline(DiffusionPipeline):
self,
batch_size: int = 1,
generator: Optional[torch.Generator] = None,
num_inference_steps: int = 1000,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
@ -53,6 +54,9 @@ class DDPMPipeline(DiffusionPipeline):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
num_inference_steps (`int`, *optional*, defaults to 1000):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@ -73,7 +77,7 @@ class DDPMPipeline(DiffusionPipeline):
image = image.to(self.device)
# set step values
self.scheduler.set_timesteps(1000)
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
# 1. predict noise model_output