From 5164c9faa9d40f55fcc4ff1940f1171a5d43b9e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Santiago=20V=C3=ADquez?= Date: Thu, 1 Sep 2022 22:17:00 +0200 Subject: [PATCH] [Type hint] Score SDE VE pipeline (#325) --- .../score_sde_ve/pipeline_score_sde_ve.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 0ab92eff..61008272 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -1,18 +1,33 @@ #!/usr/bin/env python3 import warnings +from typing import Optional import torch from diffusers import DiffusionPipeline +from ...models import UNet2DModel +from ...schedulers import ScoreSdeVeScheduler + class ScoreSdeVePipeline(DiffusionPipeline): - def __init__(self, unet, scheduler): + + unet: UNet2DModel + scheduler: ScoreSdeVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs): + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 2000, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + **kwargs, + ): if "torch_device" in kwargs: device = kwargs.pop("torch_device") warnings.warn(