diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 0ab92eff..61008272 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -1,18 +1,33 @@ #!/usr/bin/env python3 import warnings +from typing import Optional import torch from diffusers import DiffusionPipeline +from ...models import UNet2DModel +from ...schedulers import ScoreSdeVeScheduler + class ScoreSdeVePipeline(DiffusionPipeline): - def __init__(self, unet, scheduler): + + unet: UNet2DModel + scheduler: ScoreSdeVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs): + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 2000, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + **kwargs, + ): if "torch_device" in kwargs: device = kwargs.pop("torch_device") warnings.warn(