allow loading ddpm models into ddim (#1932)
This commit is contained in:
parent
beb932c5d1
commit
f6f1ec3a7c
|
@ -16,6 +16,7 @@ from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ...schedulers import DDIMScheduler
|
||||||
from ...utils import deprecate, randn_tensor
|
from ...utils import deprecate, randn_tensor
|
||||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||||
|
|
||||||
|
@ -34,6 +35,10 @@ class DDIMPipeline(DiffusionPipeline):
|
||||||
|
|
||||||
def __init__(self, unet, scheduler):
|
def __init__(self, unet, scheduler):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# make sure scheduler can always be converted to DDIM
|
||||||
|
scheduler = DDIMScheduler.from_config(scheduler.config)
|
||||||
|
|
||||||
self.register_modules(unet=unet, scheduler=scheduler)
|
self.register_modules(unet=unet, scheduler=scheduler)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
Loading…
Reference in New Issue