Merge pull request #232 from damian0815/feat_add_sde_samplers

Feat add sde samplers
This commit is contained in:
Victor Hall 2023-11-03 17:37:25 -04:00 committed by GitHub
commit 2bda841b2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 23 additions and 4 deletions

View File

@ -8,7 +8,10 @@ from typing import Generator, Callable, Any
import torch import torch
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from colorama import Fore, Style from colorama import Fore, Style
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler from diffusers import (StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler,
PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler,
KDPM2AncestralDiscreteScheduler, DPMSolverSDEScheduler, DPMSolverSinglestepScheduler)
from torch import FloatTensor from torch import FloatTensor
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -310,14 +313,30 @@ class SampleGenerator:
@torch.no_grad() @torch.no_grad()
def _create_scheduler(self, scheduler_config: dict): def _create_scheduler(self, scheduler_config: dict):
scheduler = self.scheduler scheduler = self.scheduler
if scheduler not in ['ddim', 'dpm++', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2']: if scheduler not in ['ddim', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2', 'dpm++',
'dpm++_2s', 'dpm++_2m', 'dpm++_sde', 'dpm++_2m_sde',
'dpm++_2s_k', 'dpm++_2m_k', 'dpm++_sde_k', 'dpm++_2m_sde_k']:
print(f"unsupported scheduler '{self.scheduler}', falling back to ddim") print(f"unsupported scheduler '{self.scheduler}', falling back to ddim")
scheduler = 'ddim' scheduler = 'ddim'
if scheduler == 'ddim': if scheduler == 'ddim':
return DDIMScheduler.from_config(scheduler_config) return DDIMScheduler.from_config(scheduler_config)
elif scheduler == 'dpm++': elif scheduler == 'dpm++_2s':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++") return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=False)
elif scheduler == 'dpm++_2s_k':
return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=True)
elif scheduler == 'dpm++' or scheduler == 'dpm++_2m':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=False)
elif scheduler == 'dpm++_2m_k':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=True)
elif scheduler == 'dpm++_sde':
return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=False, noise_sampler_seed=0)
elif scheduler == 'dpm++_sde_k':
return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=True, noise_sampler_seed=0)
elif scheduler == 'dpm++_2m_sde':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=False)
elif scheduler == 'dpm++_2m_sde_k':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
elif scheduler == 'pndm': elif scheduler == 'pndm':
return PNDMScheduler.from_config(scheduler_config) return PNDMScheduler.from_config(scheduler_config)
elif scheduler == 'ddpm': elif scheduler == 'ddpm':