eliminate duplicated code
add an option to samplers for skipping next to last sigma
This commit is contained in:
parent
5667ec4ca7
commit
399b229783
|
@ -23,16 +23,16 @@ samplers_k_diffusion = [
|
||||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||||
('Heun', 'sample_heun', ['k_heun'], {}),
|
('Heun', 'sample_heun', ['k_heun'], {}),
|
||||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
|
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
||||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
|
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
|
||||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
||||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
||||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
||||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
||||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
|
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
|
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
||||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
||||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
||||||
|
@ -444,9 +444,7 @@ class KDiffusionSampler:
|
||||||
|
|
||||||
return extra_params_kwargs
|
return extra_params_kwargs
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
def get_sigmas(self, p, steps):
|
||||||
steps, t_enc = setup_img2img_steps(p, steps)
|
|
||||||
|
|
||||||
if p.sampler_noise_scheduler_override:
|
if p.sampler_noise_scheduler_override:
|
||||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||||
|
@ -454,9 +452,16 @@ class KDiffusionSampler:
|
||||||
else:
|
else:
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
|
|
||||||
if self.funcname in ['sample_dpm_2_ancestral', 'sample_dpm_2']:
|
if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False):
|
||||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||||
|
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
steps, t_enc = setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
|
sigmas = self.get_sigmas(p, steps)
|
||||||
|
|
||||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||||
xi = x + noise * sigma_sched[0]
|
xi = x + noise * sigma_sched[0]
|
||||||
|
|
||||||
|
@ -488,18 +493,10 @@ class KDiffusionSampler:
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
|
|
||||||
if p.sampler_noise_scheduler_override:
|
sigmas = self.get_sigmas(p, steps)
|
||||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
|
||||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
|
||||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
|
|
||||||
else:
|
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
|
||||||
|
|
||||||
x = x * sigmas[0]
|
x = x * sigmas[0]
|
||||||
|
|
||||||
if self.funcname in ['sample_dpm_2_ancestral', 'sample_dpm_2']:
|
|
||||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
|
||||||
|
|
||||||
extra_params_kwargs = self.initialize(p)
|
extra_params_kwargs = self.initialize(p)
|
||||||
if 'sigma_min' in inspect.signature(self.func).parameters:
|
if 'sigma_min' in inspect.signature(self.func).parameters:
|
||||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||||
|
|
Loading…
Reference in New Issue