Merge pull request #232 from damian0815/feat_add_sde_samplers

Feat add sde samplers
This commit is contained in:
Victor Hall 2023-11-03 17:37:25 -04:00 committed by GitHub
commit 2bda841b2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 23 additions and 4 deletions

View File

@ -8,7 +8,10 @@ from typing import Generator, Callable, Any
import torch
from PIL import Image, ImageDraw, ImageFont
from colorama import Fore, Style
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler
from diffusers import (StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler,
PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler,
KDPM2AncestralDiscreteScheduler, DPMSolverSDEScheduler, DPMSolverSinglestepScheduler)
from torch import FloatTensor
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter
@ -310,14 +313,30 @@ class SampleGenerator:
@torch.no_grad()
def _create_scheduler(self, scheduler_config: dict):
scheduler = self.scheduler
if scheduler not in ['ddim', 'dpm++', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2']:
if scheduler not in ['ddim', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2', 'dpm++',
'dpm++_2s', 'dpm++_2m', 'dpm++_sde', 'dpm++_2m_sde',
'dpm++_2s_k', 'dpm++_2m_k', 'dpm++_sde_k', 'dpm++_2m_sde_k']:
print(f"unsupported scheduler '{self.scheduler}', falling back to ddim")
scheduler = 'ddim'
if scheduler == 'ddim':
return DDIMScheduler.from_config(scheduler_config)
elif scheduler == 'dpm++':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++")
elif scheduler == 'dpm++_2s':
return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=False)
elif scheduler == 'dpm++_2s_k':
return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=True)
elif scheduler == 'dpm++' or scheduler == 'dpm++_2m':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=False)
elif scheduler == 'dpm++_2m_k':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=True)
elif scheduler == 'dpm++_sde':
return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=False, noise_sampler_seed=0)
elif scheduler == 'dpm++_sde_k':
return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=True, noise_sampler_seed=0)
elif scheduler == 'dpm++_2m_sde':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=False)
elif scheduler == 'dpm++_2m_sde_k':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
elif scheduler == 'pndm':
return PNDMScheduler.from_config(scheduler_config)
elif scheduler == 'ddpm':