add samplers

parent 24b17efb08
commit 4b824f3acd
@@ -8,7 +8,10 @@ from typing import Generator, Callable, Any
 import torch
 from PIL import Image, ImageDraw, ImageFont
 from colorama import Fore, Style
-from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler
+from diffusers import (StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler,
+                       PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler,
+                       KDPM2AncestralDiscreteScheduler, DPMSolverSDEScheduler, DPMSolverSinglestepScheduler)
+
 from torch import FloatTensor
 from torch.cuda.amp import autocast
 from torch.utils.tensorboard import SummaryWriter
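The two classes newly added to this import, DPMSolverSDEScheduler and DPMSolverSinglestepScheduler, back the dpm++_sde* and dpm++_2s* sampler names introduced in the hunk below.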
@@ -307,14 +310,30 @@ class SampleGenerator:
     @torch.no_grad()
     def _create_scheduler(self, scheduler_config: dict):
         scheduler = self.scheduler
-        if scheduler not in ['ddim', 'dpm++', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2']:
+        if scheduler not in ['ddim', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2', 'dpm++',
+                             'dpm++_2s', 'dpm++_2m', 'dpm++_sde', 'dpm++_2m_sde',
+                             'dpm++_2s_k', 'dpm++_2m_k', 'dpm++_sde_k', 'dpm++_2m_sde_k']:
             print(f"unsupported scheduler '{self.scheduler}', falling back to ddim")
             scheduler = 'ddim'
 
         if scheduler == 'ddim':
             return DDIMScheduler.from_config(scheduler_config)
-        elif scheduler == 'dpm++':
-            return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++")
+        elif scheduler == 'dpm++_2s':
+            return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=False)
+        elif scheduler == 'dpm++_2s_k':
+            return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=True)
+        elif scheduler == 'dpm++' or scheduler == 'dpm++_2m':
+            return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=False)
+        elif scheduler == 'dpm++_2m_k':
+            return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=True)
+        elif scheduler == 'dpm++_sde':
+            return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=False, noise_sampler_seed=0)
+        elif scheduler == 'dpm++_sde_k':
+            return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=True, noise_sampler_seed=0)
+        elif scheduler == 'dpm++_2m_sde':
+            return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=False)
+        elif scheduler == 'dpm++_2m_sde_k':
+            return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
         elif scheduler == 'pndm':
             return PNDMScheduler.from_config(scheduler_config)
         elif scheduler == 'ddpm':
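For anyone wanting to sanity-check the mapping outside the trainer, the standalone sketch below (not part of the commit) mirrors the _create_scheduler dispatch with a lookup table. It assumes a diffusers version that ships DPMSolverSDEScheduler and DPMSolverSinglestepScheduler; the SAMPLERS dict and the make_scheduler helper are illustrative names, not from the codebase.

# Illustrative sketch only -- not part of this commit. It mirrors the
# _create_scheduler dispatch with a dict, so each sampler name can be
# checked against the diffusers class it resolves to.
from functools import partial

from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler,
                       DPMSolverSDEScheduler, DPMSolverSinglestepScheduler)

SAMPLERS = {
    # '2s' = single-step solver, '2m' = multistep, 'sde' = stochastic variant;
    # the '_k' suffix switches on Karras sigma spacing.
    "ddim":           DDIMScheduler.from_config,
    "dpm++_2s":       partial(DPMSolverSinglestepScheduler.from_config, use_karras_sigmas=False),
    "dpm++_2s_k":     partial(DPMSolverSinglestepScheduler.from_config, use_karras_sigmas=True),
    "dpm++_2m":       partial(DPMSolverMultistepScheduler.from_config,
                              algorithm_type="dpmsolver++", use_karras_sigmas=False),
    "dpm++_2m_k":     partial(DPMSolverMultistepScheduler.from_config,
                              algorithm_type="dpmsolver++", use_karras_sigmas=True),
    "dpm++_sde":      partial(DPMSolverSDEScheduler.from_config,
                              use_karras_sigmas=False, noise_sampler_seed=0),
    "dpm++_sde_k":    partial(DPMSolverSDEScheduler.from_config,
                              use_karras_sigmas=True, noise_sampler_seed=0),
    "dpm++_2m_sde":   partial(DPMSolverMultistepScheduler.from_config,
                              algorithm_type="sde-dpmsolver++", use_karras_sigmas=False),
    "dpm++_2m_sde_k": partial(DPMSolverMultistepScheduler.from_config,
                              algorithm_type="sde-dpmsolver++", use_karras_sigmas=True),
}

def make_scheduler(name: str, scheduler_config: dict):
    """Build a scheduler from a pipeline's scheduler config, falling back to ddim."""
    factory = SAMPLERS.get(name) or SAMPLERS["ddim"]
    return factory(scheduler_config)

As in the commit, 'dpm++' stays an alias for the multistep dpm++_2m solver, and noise_sampler_seed=0 pins the SDE variants' noise, presumably so sample images are reproducible across runs.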