Merge pull request #14978 from drhead/refiner_fix
Make refiner switchover based on model timesteps instead of sampling steps
This commit is contained in:
commit
aabedcbcc7
|
@ -152,7 +152,7 @@ class CFGDenoiser(torch.nn.Module):
|
||||||
if state.interrupted or state.skipped:
|
if state.interrupted or state.skipped:
|
||||||
raise sd_samplers_common.InterruptedException
|
raise sd_samplers_common.InterruptedException
|
||||||
|
|
||||||
if sd_samplers_common.apply_refiner(self):
|
if sd_samplers_common.apply_refiner(self, sigma):
|
||||||
cond = self.sampler.sampler_extra_args['cond']
|
cond = self.sampler.sampler_extra_args['cond']
|
||||||
uncond = self.sampler.sampler_extra_args['uncond']
|
uncond = self.sampler.sampler_extra_args['uncond']
|
||||||
|
|
||||||
|
|
|
@ -155,8 +155,17 @@ def replace_torchsde_browinan():
|
||||||
replace_torchsde_browinan()
|
replace_torchsde_browinan()
|
||||||
|
|
||||||
|
|
||||||
def apply_refiner(cfg_denoiser):
|
def apply_refiner(cfg_denoiser, sigma):
|
||||||
|
if opts.refiner_switch_by_sample_steps:
|
||||||
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||||
|
else:
|
||||||
|
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
|
||||||
|
try:
|
||||||
|
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
|
||||||
|
except AttributeError: # for samplers that dont use sigmas (DDIM) sigma is actually the timestep
|
||||||
|
timestep = torch.max(sigma).to(dtype=int)
|
||||||
|
completed_ratio = (999 - timestep) / 1000
|
||||||
|
|
||||||
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||||
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||||
|
|
||||||
|
|
|
@ -227,7 +227,8 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd"
|
||||||
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
||||||
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
||||||
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
|
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
|
||||||
"use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod")
|
"use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod"),
|
||||||
|
"refiner_switch_by_sample_steps": OptionInfo(False, "Switch to refiner by sampling steps instead of model timesteps. Old behavior for refiner.", infotext="Refiner switch by sampling steps")
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
||||||
|
|
Loading…
Reference in New Issue