[Type hint] Score SDE VE pipeline (#325)
This commit is contained in:
parent
93debd301d
commit
5164c9faa9
|
@ -1,18 +1,33 @@
|
|||
#!/usr/bin/env python3
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...schedulers import ScoreSdeVeScheduler
|
||||
|
||||
|
||||
class ScoreSdeVePipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
|
||||
unet: UNet2DModel
|
||||
scheduler: ScoreSdeVeScheduler
|
||||
|
||||
def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs):
|
||||
def __call__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
num_inference_steps: int = 2000,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
**kwargs,
|
||||
):
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
|
|
Loading…
Reference in New Issue