allow loading ddpm models into ddim (#1932)

This commit is contained in:
Patrick von Platen 2023-01-10 14:52:32 +01:00 committed by GitHub
parent beb932c5d1
commit f6f1ec3a7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 0 deletions

View File

@ -16,6 +16,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from ...schedulers import DDIMScheduler
from ...utils import deprecate, randn_tensor from ...utils import deprecate, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@ -34,6 +35,10 @@ class DDIMPipeline(DiffusionPipeline):
def __init__(self, unet, scheduler): def __init__(self, unet, scheduler):
super().__init__() super().__init__()
# make sure scheduler can always be converted to DDIM
scheduler = DDIMScheduler.from_config(scheduler.config)
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()