Add s_noise param to more samplers
This commit is contained in:
parent
1ae9dacb4b
commit
d1a70c3f05
|
@ -276,19 +276,19 @@ class Sampler:
|
|||
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||
|
||||
if s_churn != self.s_churn:
|
||||
if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
|
||||
extra_params_kwargs['s_churn'] = s_churn
|
||||
p.s_churn = s_churn
|
||||
p.extra_generation_params['Sigma churn'] = s_churn
|
||||
if s_tmin != self.s_tmin:
|
||||
if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
|
||||
extra_params_kwargs['s_tmin'] = s_tmin
|
||||
p.s_tmin = s_tmin
|
||||
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||
if s_tmax != self.s_tmax:
|
||||
if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
|
||||
extra_params_kwargs['s_tmax'] = s_tmax
|
||||
p.s_tmax = s_tmax
|
||||
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||
if s_noise != self.s_noise:
|
||||
if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
|
||||
extra_params_kwargs['s_noise'] = s_noise
|
||||
p.s_noise = s_noise
|
||||
p.extra_generation_params['Sigma noise'] = s_noise
|
||||
|
|
|
@ -45,6 +45,12 @@ sampler_extra_params = {
|
|||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_dpm_fast': ['s_noise'],
|
||||
'sample_dpm_2_ancestral': ['s_noise'],
|
||||
'sample_dpmpp_2s_ancestral': ['s_noise'],
|
||||
'sample_dpmpp_sde': ['s_noise'],
|
||||
'sample_dpmpp_2m_sde': ['s_noise'],
|
||||
'sample_dpmpp_3m_sde': ['s_noise'],
|
||||
}
|
||||
|
||||
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
||||
|
|
Loading…
Reference in New Issue