[Type hint] Latent Diffusion Uncond pipeline (#333)

This commit is contained in:
Santiago Víquez 2022-09-02 16:39:34 +02:00 committed by GitHub
parent e54206d095
commit 033b77ebc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 2 deletions

View File

@ -1,19 +1,35 @@
import inspect
import warnings
from typing import Optional
import torch
from ...models import UNet2DModel, VQModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler
class LDMPipeline(DiffusionPipeline):
def __init__(self, vqvae, unet, scheduler):
vqvae: VQModel
unet: UNet2DModel
scheduler: DDIMScheduler
def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
def __call__(
self,
batch_size: int = 1,
generator: Optional[torch.Generator] = None,
eta: float = 0.0,
num_inference_steps: int = 50,
output_type: Optional[str] = "pil",
**kwargs,
):
# eta corresponds to η in paper and should be between [0, 1]
if "torch_device" in kwargs: