support scheduler selection in hires fix
This commit is contained in:
parent
755d2cb2e5
commit
9aa9e980a9
|
@ -314,6 +314,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||||
if "Hires sampler" not in res:
|
if "Hires sampler" not in res:
|
||||||
res["Hires sampler"] = "Use same sampler"
|
res["Hires sampler"] = "Use same sampler"
|
||||||
|
|
||||||
|
if "Hires schedule type" not in res:
|
||||||
|
res["Hires schedule type"] = "Use same scheduler"
|
||||||
|
|
||||||
if "Hires checkpoint" not in res:
|
if "Hires checkpoint" not in res:
|
||||||
res["Hires checkpoint"] = "Use same checkpoint"
|
res["Hires checkpoint"] = "Use same checkpoint"
|
||||||
|
|
||||||
|
|
|
@ -1115,6 +1115,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
hr_resize_y: int = 0
|
hr_resize_y: int = 0
|
||||||
hr_checkpoint_name: str = None
|
hr_checkpoint_name: str = None
|
||||||
hr_sampler_name: str = None
|
hr_sampler_name: str = None
|
||||||
|
hr_scheduler: str = None
|
||||||
hr_prompt: str = ''
|
hr_prompt: str = ''
|
||||||
hr_negative_prompt: str = ''
|
hr_negative_prompt: str = ''
|
||||||
force_task_id: str = None
|
force_task_id: str = None
|
||||||
|
@ -1203,6 +1204,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
||||||
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
||||||
|
|
||||||
|
self.extra_generation_params["Hires schedule type"] = None # to be set in sd_samplers_kdiffusion.py
|
||||||
|
|
||||||
|
if self.hr_scheduler is None:
|
||||||
|
self.hr_scheduler = self.scheduler
|
||||||
|
|
||||||
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
||||||
if self.enable_hr and self.latent_scale_mode is None:
|
if self.enable_hr and self.latent_scale_mode is None:
|
||||||
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
|
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
|
||||||
|
|
|
@ -1,44 +1,10 @@
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import functools
|
|
||||||
|
|
||||||
from modules import scripts, sd_samplers, sd_schedulers, shared
|
from modules import scripts, sd_samplers, sd_schedulers, shared
|
||||||
from modules.infotext_utils import PasteField
|
from modules.infotext_utils import PasteField
|
||||||
from modules.ui_components import FormRow, FormGroup
|
from modules.ui_components import FormRow, FormGroup
|
||||||
|
|
||||||
|
|
||||||
def get_sampler_from_infotext(d: dict):
|
|
||||||
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
|
|
||||||
|
|
||||||
|
|
||||||
def get_scheduler_from_infotext(d: dict):
|
|
||||||
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
|
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def get_sampler_and_scheduler(sampler_name, scheduler_name):
|
|
||||||
default_sampler = sd_samplers.samplers[0]
|
|
||||||
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
|
|
||||||
|
|
||||||
name = sampler_name or default_sampler.name
|
|
||||||
|
|
||||||
for scheduler in sd_schedulers.schedulers:
|
|
||||||
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
|
|
||||||
|
|
||||||
for name_option in name_options:
|
|
||||||
if name.endswith(" " + name_option):
|
|
||||||
found_scheduler = scheduler
|
|
||||||
name = name[0:-(len(name_option) + 1)]
|
|
||||||
break
|
|
||||||
|
|
||||||
sampler = sd_samplers.all_samplers_map.get(name, default_sampler)
|
|
||||||
|
|
||||||
# revert back to Automatic if it's the default scheduler for the selected sampler
|
|
||||||
if sampler.options.get('scheduler', None) == found_scheduler.name:
|
|
||||||
found_scheduler = sd_schedulers.schedulers[0]
|
|
||||||
|
|
||||||
return sampler.name, found_scheduler.label
|
|
||||||
|
|
||||||
|
|
||||||
class ScriptSampler(scripts.ScriptBuiltinUI):
|
class ScriptSampler(scripts.ScriptBuiltinUI):
|
||||||
section = "sampler"
|
section = "sampler"
|
||||||
|
|
||||||
|
@ -67,8 +33,8 @@ class ScriptSampler(scripts.ScriptBuiltinUI):
|
||||||
|
|
||||||
self.infotext_fields = [
|
self.infotext_fields = [
|
||||||
PasteField(self.steps, "Steps", api="steps"),
|
PasteField(self.steps, "Steps", api="steps"),
|
||||||
PasteField(self.sampler_name, get_sampler_from_infotext, api="sampler_name"),
|
PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"),
|
||||||
PasteField(self.scheduler, get_scheduler_from_infotext, api="scheduler"),
|
PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"),
|
||||||
]
|
]
|
||||||
|
|
||||||
return self.steps, self.sampler_name, self.scheduler
|
return self.steps, self.sampler_name, self.scheduler
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common
|
import functools
|
||||||
|
|
||||||
|
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
|
||||||
|
|
||||||
# imports for functions that previously were here and are used by other modules
|
# imports for functions that previously were here and are used by other modules
|
||||||
samples_to_image_grid = sd_samplers_common.samples_to_image_grid
|
samples_to_image_grid = sd_samplers_common.samples_to_image_grid
|
||||||
|
@ -64,4 +66,60 @@ def visible_samplers():
|
||||||
return [x for x in samplers if x.name not in samplers_hidden]
|
return [x for x in samplers if x.name not in samplers_hidden]
|
||||||
|
|
||||||
|
|
||||||
|
def get_sampler_from_infotext(d: dict):
|
||||||
|
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler_from_infotext(d: dict):
|
||||||
|
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
|
||||||
|
|
||||||
|
|
||||||
|
def get_hr_sampler_and_scheduler(d: dict):
|
||||||
|
hr_sampler = d.get("Hires sampler", "Use same sampler")
|
||||||
|
sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler
|
||||||
|
|
||||||
|
hr_scheduler = d.get("Hires schedule type", "Use same scheduler")
|
||||||
|
scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler
|
||||||
|
|
||||||
|
sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler)
|
||||||
|
|
||||||
|
sampler = sampler if sampler != d.get("Sampler") else "Use same sampler"
|
||||||
|
scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler"
|
||||||
|
|
||||||
|
return sampler, scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def get_hr_sampler_from_infotext(d: dict):
|
||||||
|
return get_hr_sampler_and_scheduler(d)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_hr_scheduler_from_infotext(d: dict):
|
||||||
|
return get_hr_sampler_and_scheduler(d)[1]
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def get_sampler_and_scheduler(sampler_name, scheduler_name):
|
||||||
|
default_sampler = samplers[0]
|
||||||
|
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
|
||||||
|
|
||||||
|
name = sampler_name or default_sampler.name
|
||||||
|
|
||||||
|
for scheduler in sd_schedulers.schedulers:
|
||||||
|
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
|
||||||
|
|
||||||
|
for name_option in name_options:
|
||||||
|
if name.endswith(" " + name_option):
|
||||||
|
found_scheduler = scheduler
|
||||||
|
name = name[0:-(len(name_option) + 1)]
|
||||||
|
break
|
||||||
|
|
||||||
|
sampler = all_samplers_map.get(name, default_sampler)
|
||||||
|
|
||||||
|
# revert back to Automatic if it's the default scheduler for the selected sampler
|
||||||
|
if sampler.options.get('scheduler', None) == found_scheduler.name:
|
||||||
|
found_scheduler = sd_schedulers.schedulers[0]
|
||||||
|
|
||||||
|
return sampler.name, found_scheduler.label
|
||||||
|
|
||||||
|
|
||||||
set_samplers()
|
set_samplers()
|
||||||
|
|
|
@ -79,7 +79,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||||
|
|
||||||
steps += 1 if discard_next_to_last_sigma else 0
|
steps += 1 if discard_next_to_last_sigma else 0
|
||||||
|
|
||||||
scheduler_name = p.scheduler or 'Automatic'
|
scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic'
|
||||||
if scheduler_name == 'Automatic':
|
if scheduler_name == 'Automatic':
|
||||||
scheduler_name = self.config.options.get('scheduler', None)
|
scheduler_name = self.config.options.get('scheduler', None)
|
||||||
|
|
||||||
|
@ -95,8 +95,10 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||||
else:
|
else:
|
||||||
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
|
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
|
||||||
|
|
||||||
if scheduler.label != 'Automatic':
|
if scheduler.label != 'Automatic' and not p.is_hr_pass:
|
||||||
p.extra_generation_params["Schedule type"] = scheduler.label
|
p.extra_generation_params["Schedule type"] = scheduler.label
|
||||||
|
elif scheduler.label != p.extra_generation_params.get("Schedule type"):
|
||||||
|
p.extra_generation_params["Hires schedule type"] = scheduler.label
|
||||||
|
|
||||||
if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
|
if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
|
||||||
sigmas_kwargs['sigma_min'] = opts.sigma_min
|
sigmas_kwargs['sigma_min'] = opts.sigma_min
|
||||||
|
|
|
@ -11,7 +11,7 @@ from PIL import Image
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
|
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
if force_enable_hr:
|
if force_enable_hr:
|
||||||
|
@ -38,6 +38,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
|
||||||
hr_resize_y=hr_resize_y,
|
hr_resize_y=hr_resize_y,
|
||||||
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
|
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
|
||||||
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
|
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
|
||||||
|
hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler,
|
||||||
hr_prompt=hr_prompt,
|
hr_prompt=hr_prompt,
|
||||||
hr_negative_prompt=hr_negative_prompt,
|
hr_negative_prompt=hr_negative_prompt,
|
||||||
override_settings=override_settings,
|
override_settings=override_settings,
|
||||||
|
|
|
@ -322,10 +322,11 @@ def create_ui():
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
||||||
|
|
||||||
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||||
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||||
|
|
||||||
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
hr_sampler_name = gr.Dropdown(label='Sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
||||||
|
hr_scheduler = gr.Dropdown(label='Schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler")
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
|
@ -394,6 +395,7 @@ def create_ui():
|
||||||
hr_resize_y,
|
hr_resize_y,
|
||||||
hr_checkpoint_name,
|
hr_checkpoint_name,
|
||||||
hr_sampler_name,
|
hr_sampler_name,
|
||||||
|
hr_scheduler,
|
||||||
hr_prompt,
|
hr_prompt,
|
||||||
hr_negative_prompt,
|
hr_negative_prompt,
|
||||||
override_settings,
|
override_settings,
|
||||||
|
@ -456,8 +458,9 @@ def create_ui():
|
||||||
PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
|
PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
|
||||||
PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
|
PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
|
||||||
PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
|
PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
|
||||||
PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
|
PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"),
|
||||||
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
|
PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"),
|
||||||
|
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()),
|
||||||
PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
|
PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
|
||||||
PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
|
PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
|
||||||
PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
|
PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
|
||||||
|
|
Loading…
Reference in New Issue