[Type hint] Score SDE VE pipeline (#325)

This commit is contained in:
Santiago Víquez 2022-09-01 22:17:00 +02:00 committed by GitHub
parent 93debd301d
commit 5164c9faa9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 17 additions and 2 deletions

View File

@ -1,18 +1,33 @@
#!/usr/bin/env python3
import warnings
from typing import Optional
import torch
from diffusers import DiffusionPipeline
from ...models import UNet2DModel
from ...schedulers import ScoreSdeVeScheduler
class ScoreSdeVePipeline(DiffusionPipeline):
def __init__(self, unet, scheduler):
unet: UNet2DModel
scheduler: ScoreSdeVeScheduler
def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, output_type="pil", **kwargs):
def __call__(
self,
batch_size: int = 1,
num_inference_steps: int = 2000,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
**kwargs,
):
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(