pass extra KDiffusionSampler function parameters
This commit is contained in:
parent
6b78833e33
commit
2ab3d593f9
|
@ -37,6 +37,11 @@ samplers = [
|
||||||
]
|
]
|
||||||
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
|
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
|
||||||
|
|
||||||
|
sampler_extra_params = {
|
||||||
|
'sample_euler':['s_churn','s_tmin','s_noise'],
|
||||||
|
'sample_heun' :['s_churn','s_tmin','s_noise'],
|
||||||
|
'sample_dpm_2':['s_churn','s_tmin','s_noise'],
|
||||||
|
}
|
||||||
|
|
||||||
def setup_img2img_steps(p, steps=None):
|
def setup_img2img_steps(p, steps=None):
|
||||||
if opts.img2img_fix_steps or steps is not None:
|
if opts.img2img_fix_steps or steps is not None:
|
||||||
|
@ -224,6 +229,7 @@ class KDiffusionSampler:
|
||||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
|
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||||
self.funcname = funcname
|
self.funcname = funcname
|
||||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||||
|
self.extra_params = sampler_extra_params.get(funcname,[])
|
||||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||||
self.sampler_noises = None
|
self.sampler_noises = None
|
||||||
self.sampler_noise_index = 0
|
self.sampler_noise_index = 0
|
||||||
|
@ -269,7 +275,12 @@ class KDiffusionSampler:
|
||||||
if self.sampler_noises is not None:
|
if self.sampler_noises is not None:
|
||||||
k_diffusion.sampling.torch = TorchHijack(self)
|
k_diffusion.sampling.torch = TorchHijack(self)
|
||||||
|
|
||||||
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
|
extra_params_kwargs = {}
|
||||||
|
for val in self.extra_params:
|
||||||
|
if hasattr(opts,val):
|
||||||
|
extra_params_kwargs[val] = getattr(opts,val)
|
||||||
|
|
||||||
|
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
|
@ -286,7 +297,12 @@ class KDiffusionSampler:
|
||||||
if self.sampler_noises is not None:
|
if self.sampler_noises is not None:
|
||||||
k_diffusion.sampling.torch = TorchHijack(self)
|
k_diffusion.sampling.torch = TorchHijack(self)
|
||||||
|
|
||||||
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
|
extra_params_kwargs = {}
|
||||||
|
for val in self.extra_params:
|
||||||
|
if hasattr(opts,val):
|
||||||
|
extra_params_kwargs[val] = getattr(opts,val)
|
||||||
|
|
||||||
|
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue