diff --git a/utils/sample_generator.py b/utils/sample_generator.py index 929d3b8..ad40123 100644 --- a/utils/sample_generator.py +++ b/utils/sample_generator.py @@ -8,7 +8,10 @@ from typing import Generator, Callable, Any import torch from PIL import Image, ImageDraw, ImageFont from colorama import Fore, Style -from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler +from diffusers import (StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, + PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, + KDPM2AncestralDiscreteScheduler, DPMSolverSDEScheduler, DPMSolverSinglestepScheduler) + from torch import FloatTensor from torch.cuda.amp import autocast from torch.utils.tensorboard import SummaryWriter @@ -310,14 +313,30 @@ class SampleGenerator: @torch.no_grad() def _create_scheduler(self, scheduler_config: dict): scheduler = self.scheduler - if scheduler not in ['ddim', 'dpm++', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2']: + if scheduler not in ['ddim', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2', 'dpm++', + 'dpm++_2s', 'dpm++_2m', 'dpm++_sde', 'dpm++_2m_sde', + 'dpm++_2s_k', 'dpm++_2m_k', 'dpm++_sde_k', 'dpm++_2m_sde_k']: print(f"unsupported scheduler '{self.scheduler}', falling back to ddim") scheduler = 'ddim' if scheduler == 'ddim': return DDIMScheduler.from_config(scheduler_config) - elif scheduler == 'dpm++': - return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++") + elif scheduler == 'dpm++_2s': + return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=False) + elif scheduler == 'dpm++_2s_k': + return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=True) + elif scheduler == 'dpm++' or scheduler == 'dpm++_2m': + return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=False) + elif scheduler == 'dpm++_2m_k': + return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=True) + elif scheduler == 'dpm++_sde': + return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=False, noise_sampler_seed=0) + elif scheduler == 'dpm++_sde_k': + return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=True, noise_sampler_seed=0) + elif scheduler == 'dpm++_2m_sde': + return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=False) + elif scheduler == 'dpm++_2m_sde_k': + return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True) elif scheduler == 'pndm': return PNDMScheduler.from_config(scheduler_config) elif scheduler == 'ddpm':