add samplers

This commit is contained in:
Damian Stewart 2023-08-10 20:47:28 +02:00
parent 24b17efb08
commit 4b824f3acd
1 changed files with 23 additions and 4 deletions

View File

@ -8,7 +8,10 @@ from typing import Generator, Callable, Any
import torch
from PIL import Image, ImageDraw, ImageFont
from colorama import Fore, Style
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler
from diffusers import (StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler,
PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler,
KDPM2AncestralDiscreteScheduler, DPMSolverSDEScheduler, DPMSolverSinglestepScheduler)
from torch import FloatTensor
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter
@ -307,14 +310,30 @@ class SampleGenerator:
@torch.no_grad()
def _create_scheduler(self, scheduler_config: dict):
scheduler = self.scheduler
if scheduler not in ['ddim', 'dpm++', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2']:
if scheduler not in ['ddim', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2', 'dpm++',
'dpm++_2s', 'dpm++_2m', 'dpm++_sde', 'dpm++_2m_sde',
'dpm++_2s_k', 'dpm++_2m_k', 'dpm++_sde_k', 'dpm++_2m_sde_k']:
print(f"unsupported scheduler '{self.scheduler}', falling back to ddim")
scheduler = 'ddim'
if scheduler == 'ddim':
return DDIMScheduler.from_config(scheduler_config)
elif scheduler == 'dpm++':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++")
elif scheduler == 'dpm++_2s':
return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=False)
elif scheduler == 'dpm++_2s_k':
return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=True)
elif scheduler == 'dpm++' or scheduler == 'dpm++_2m':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=False)
elif scheduler == 'dpm++_2m_k':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=True)
elif scheduler == 'dpm++_sde':
return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=False, noise_sampler_seed=0)
elif scheduler == 'dpm++_sde_k':
return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=True, noise_sampler_seed=0)
elif scheduler == 'dpm++_2m_sde':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=False)
elif scheduler == 'dpm++_2m_sde_k':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
elif scheduler == 'pndm':
return PNDMScheduler.from_config(scheduler_config)
elif scheduler == 'ddpm':