add samplers
This commit is contained in:
parent
24b17efb08
commit
4b824f3acd
|
@ -8,7 +8,10 @@ from typing import Generator, Callable, Any
|
|||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from colorama import Fore, Style
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler
|
||||
from diffusers import (StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler,
|
||||
PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler, DPMSolverSDEScheduler, DPMSolverSinglestepScheduler)
|
||||
|
||||
from torch import FloatTensor
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
@ -307,14 +310,30 @@ class SampleGenerator:
|
|||
@torch.no_grad()
|
||||
def _create_scheduler(self, scheduler_config: dict):
|
||||
scheduler = self.scheduler
|
||||
if scheduler not in ['ddim', 'dpm++', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2']:
|
||||
if scheduler not in ['ddim', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2', 'dpm++',
|
||||
'dpm++_2s', 'dpm++_2m', 'dpm++_sde', 'dpm++_2m_sde',
|
||||
'dpm++_2s_k', 'dpm++_2m_k', 'dpm++_sde_k', 'dpm++_2m_sde_k']:
|
||||
print(f"unsupported scheduler '{self.scheduler}', falling back to ddim")
|
||||
scheduler = 'ddim'
|
||||
|
||||
if scheduler == 'ddim':
|
||||
return DDIMScheduler.from_config(scheduler_config)
|
||||
elif scheduler == 'dpm++':
|
||||
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++")
|
||||
elif scheduler == 'dpm++_2s':
|
||||
return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=False)
|
||||
elif scheduler == 'dpm++_2s_k':
|
||||
return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=True)
|
||||
elif scheduler == 'dpm++' or scheduler == 'dpm++_2m':
|
||||
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=False)
|
||||
elif scheduler == 'dpm++_2m_k':
|
||||
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=True)
|
||||
elif scheduler == 'dpm++_sde':
|
||||
return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=False, noise_sampler_seed=0)
|
||||
elif scheduler == 'dpm++_sde_k':
|
||||
return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=True, noise_sampler_seed=0)
|
||||
elif scheduler == 'dpm++_2m_sde':
|
||||
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=False)
|
||||
elif scheduler == 'dpm++_2m_sde_k':
|
||||
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
|
||||
elif scheduler == 'pndm':
|
||||
return PNDMScheduler.from_config(scheduler_config)
|
||||
elif scheduler == 'ddpm':
|
||||
|
|
Loading…
Reference in New Issue