add karras scheduling variants
This commit is contained in:
parent
c1a068ed0a
commit
71901b3d3b
|
@ -26,6 +26,17 @@ samplers_k_diffusion = [
|
|||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
|
||||
]
|
||||
|
||||
if opts.show_karras_scheduler_variants:
|
||||
k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2
|
||||
k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral
|
||||
k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms
|
||||
samplers_k_diffusion_ka = [
|
||||
('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']),
|
||||
('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']),
|
||||
('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']),
|
||||
]
|
||||
samplers_k_diffusion.extend(samplers_k_diffusion_ka)
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases)
|
||||
for label, funcname, aliases in samplers_k_diffusion
|
||||
|
@ -345,6 +356,8 @@ class KDiffusionSampler:
|
|||
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
elif self.funcname.endswith('ka'):
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
x = x * sigmas[0]
|
||||
|
|
Loading…
Reference in New Issue