allow loading ddpm models into ddim (#1932)
This commit is contained in:
parent
beb932c5d1
commit
f6f1ec3a7c
|
@ -16,6 +16,7 @@ from typing import List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
|
||||
from ...schedulers import DDIMScheduler
|
||||
from ...utils import deprecate, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
@ -34,6 +35,10 @@ class DDIMPipeline(DiffusionPipeline):
|
|||
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
# make sure scheduler can always be converted to DDIM
|
||||
scheduler = DDIMScheduler.from_config(scheduler.config)
|
||||
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
Loading…
Reference in New Issue