add samplers

This commit is contained in:
Damian Stewart 2023-08-10 20:47:28 +02:00
parent 24b17efb08
commit 4b824f3acd
1 changed files with 23 additions and 4 deletions

View File

@ -8,7 +8,10 @@ from typing import Generator, Callable, Any
import torch import torch
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from colorama import Fore, Style from colorama import Fore, Style
from diffusers import StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler, KDPM2AncestralDiscreteScheduler from diffusers import (StableDiffusionPipeline, DDIMScheduler, DPMSolverMultistepScheduler, DDPMScheduler,
PNDMScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler,
KDPM2AncestralDiscreteScheduler, DPMSolverSDEScheduler, DPMSolverSinglestepScheduler)
from torch import FloatTensor from torch import FloatTensor
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -307,14 +310,30 @@ class SampleGenerator:
@torch.no_grad() @torch.no_grad()
def _create_scheduler(self, scheduler_config: dict): def _create_scheduler(self, scheduler_config: dict):
scheduler = self.scheduler scheduler = self.scheduler
if scheduler not in ['ddim', 'dpm++', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2']: if scheduler not in ['ddim', 'pndm', 'ddpm', 'lms', 'euler', 'euler_a', 'kdpm2', 'dpm++',
'dpm++_2s', 'dpm++_2m', 'dpm++_sde', 'dpm++_2m_sde',
'dpm++_2s_k', 'dpm++_2m_k', 'dpm++_sde_k', 'dpm++_2m_sde_k']:
print(f"unsupported scheduler '{self.scheduler}', falling back to ddim") print(f"unsupported scheduler '{self.scheduler}', falling back to ddim")
scheduler = 'ddim' scheduler = 'ddim'
if scheduler == 'ddim': if scheduler == 'ddim':
return DDIMScheduler.from_config(scheduler_config) return DDIMScheduler.from_config(scheduler_config)
elif scheduler == 'dpm++': elif scheduler == 'dpm++_2s':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++") return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=False)
elif scheduler == 'dpm++_2s_k':
return DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=True)
elif scheduler == 'dpm++' or scheduler == 'dpm++_2m':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=False)
elif scheduler == 'dpm++_2m_k':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="dpmsolver++", use_karras_sigmas=True)
elif scheduler == 'dpm++_sde':
return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=False, noise_sampler_seed=0)
elif scheduler == 'dpm++_sde_k':
return DPMSolverSDEScheduler.from_config(scheduler_config, use_karras_sigmas=True, noise_sampler_seed=0)
elif scheduler == 'dpm++_2m_sde':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=False)
elif scheduler == 'dpm++_2m_sde_k':
return DPMSolverMultistepScheduler.from_config(scheduler_config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
elif scheduler == 'pndm': elif scheduler == 'pndm':
return PNDMScheduler.from_config(scheduler_config) return PNDMScheduler.from_config(scheduler_config)
elif scheduler == 'ddpm': elif scheduler == 'ddpm':