[Type hint] Latent Diffusion Uncond pipeline (#333)
This commit is contained in:
parent
e54206d095
commit
033b77ebc4
|
@ -1,19 +1,35 @@
|
|||
import inspect
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DModel, VQModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler
|
||||
|
||||
|
||||
class LDMPipeline(DiffusionPipeline):
|
||||
def __init__(self, vqvae, unet, scheduler):
|
||||
|
||||
vqvae: VQModel
|
||||
unet: UNet2DModel
|
||||
scheduler: DDIMScheduler
|
||||
|
||||
def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
|
||||
def __call__(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
eta: float = 0.0,
|
||||
num_inference_steps: int = 50,
|
||||
output_type: Optional[str] = "pil",
|
||||
**kwargs,
|
||||
):
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
|
||||
if "torch_device" in kwargs:
|
||||
|
|
Loading…
Reference in New Issue