support scheduler selection in hires fix

This commit is contained in:
AUTOMATIC1111 2024-03-24 11:00:16 +03:00
parent 755d2cb2e5
commit 9aa9e980a9
7 changed files with 83 additions and 44 deletions

View File

@ -314,6 +314,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Hires sampler" not in res: if "Hires sampler" not in res:
res["Hires sampler"] = "Use same sampler" res["Hires sampler"] = "Use same sampler"
if "Hires schedule type" not in res:
res["Hires schedule type"] = "Use same scheduler"
if "Hires checkpoint" not in res: if "Hires checkpoint" not in res:
res["Hires checkpoint"] = "Use same checkpoint" res["Hires checkpoint"] = "Use same checkpoint"

View File

@ -1115,6 +1115,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_resize_y: int = 0 hr_resize_y: int = 0
hr_checkpoint_name: str = None hr_checkpoint_name: str = None
hr_sampler_name: str = None hr_sampler_name: str = None
hr_scheduler: str = None
hr_prompt: str = '' hr_prompt: str = ''
hr_negative_prompt: str = '' hr_negative_prompt: str = ''
force_task_id: str = None force_task_id: str = None
@ -1203,6 +1204,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name: if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
self.extra_generation_params["Hires schedule type"] = None # to be set in sd_samplers_kdiffusion.py
if self.hr_scheduler is None:
self.hr_scheduler = self.scheduler
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and self.latent_scale_mode is None: if self.enable_hr and self.latent_scale_mode is None:
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):

View File

@ -1,44 +1,10 @@
import gradio as gr import gradio as gr
import functools
from modules import scripts, sd_samplers, sd_schedulers, shared from modules import scripts, sd_samplers, sd_schedulers, shared
from modules.infotext_utils import PasteField from modules.infotext_utils import PasteField
from modules.ui_components import FormRow, FormGroup from modules.ui_components import FormRow, FormGroup
def get_sampler_from_infotext(d: dict):
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
def get_scheduler_from_infotext(d: dict):
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
@functools.cache
def get_sampler_and_scheduler(sampler_name, scheduler_name):
default_sampler = sd_samplers.samplers[0]
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
name = sampler_name or default_sampler.name
for scheduler in sd_schedulers.schedulers:
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
for name_option in name_options:
if name.endswith(" " + name_option):
found_scheduler = scheduler
name = name[0:-(len(name_option) + 1)]
break
sampler = sd_samplers.all_samplers_map.get(name, default_sampler)
# revert back to Automatic if it's the default scheduler for the selected sampler
if sampler.options.get('scheduler', None) == found_scheduler.name:
found_scheduler = sd_schedulers.schedulers[0]
return sampler.name, found_scheduler.label
class ScriptSampler(scripts.ScriptBuiltinUI): class ScriptSampler(scripts.ScriptBuiltinUI):
section = "sampler" section = "sampler"
@ -67,8 +33,8 @@ class ScriptSampler(scripts.ScriptBuiltinUI):
self.infotext_fields = [ self.infotext_fields = [
PasteField(self.steps, "Steps", api="steps"), PasteField(self.steps, "Steps", api="steps"),
PasteField(self.sampler_name, get_sampler_from_infotext, api="sampler_name"), PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"),
PasteField(self.scheduler, get_scheduler_from_infotext, api="scheduler"), PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"),
] ]
return self.steps, self.sampler_name, self.scheduler return self.steps, self.sampler_name, self.scheduler

View File

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common import functools
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
# imports for functions that previously were here and are used by other modules # imports for functions that previously were here and are used by other modules
samples_to_image_grid = sd_samplers_common.samples_to_image_grid samples_to_image_grid = sd_samplers_common.samples_to_image_grid
@ -64,4 +66,60 @@ def visible_samplers():
return [x for x in samplers if x.name not in samplers_hidden] return [x for x in samplers if x.name not in samplers_hidden]
def get_sampler_from_infotext(d: dict):
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
def get_scheduler_from_infotext(d: dict):
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
def get_hr_sampler_and_scheduler(d: dict):
hr_sampler = d.get("Hires sampler", "Use same sampler")
sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler
hr_scheduler = d.get("Hires schedule type", "Use same scheduler")
scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler
sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler)
sampler = sampler if sampler != d.get("Sampler") else "Use same sampler"
scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler"
return sampler, scheduler
def get_hr_sampler_from_infotext(d: dict):
return get_hr_sampler_and_scheduler(d)[0]
def get_hr_scheduler_from_infotext(d: dict):
return get_hr_sampler_and_scheduler(d)[1]
@functools.cache
def get_sampler_and_scheduler(sampler_name, scheduler_name):
default_sampler = samplers[0]
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
name = sampler_name or default_sampler.name
for scheduler in sd_schedulers.schedulers:
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
for name_option in name_options:
if name.endswith(" " + name_option):
found_scheduler = scheduler
name = name[0:-(len(name_option) + 1)]
break
sampler = all_samplers_map.get(name, default_sampler)
# revert back to Automatic if it's the default scheduler for the selected sampler
if sampler.options.get('scheduler', None) == found_scheduler.name:
found_scheduler = sd_schedulers.schedulers[0]
return sampler.name, found_scheduler.label
set_samplers() set_samplers()

View File

@ -79,7 +79,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
steps += 1 if discard_next_to_last_sigma else 0 steps += 1 if discard_next_to_last_sigma else 0
scheduler_name = p.scheduler or 'Automatic' scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic'
if scheduler_name == 'Automatic': if scheduler_name == 'Automatic':
scheduler_name = self.config.options.get('scheduler', None) scheduler_name = self.config.options.get('scheduler', None)
@ -95,8 +95,10 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
else: else:
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max} sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
if scheduler.label != 'Automatic': if scheduler.label != 'Automatic' and not p.is_hr_pass:
p.extra_generation_params["Schedule type"] = scheduler.label p.extra_generation_params["Schedule type"] = scheduler.label
elif scheduler.label != p.extra_generation_params.get("Schedule type"):
p.extra_generation_params["Hires schedule type"] = scheduler.label
if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min: if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
sigmas_kwargs['sigma_min'] = opts.sigma_min sigmas_kwargs['sigma_min'] = opts.sigma_min

View File

@ -11,7 +11,7 @@ from PIL import Image
import gradio as gr import gradio as gr
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False): def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
if force_enable_hr: if force_enable_hr:
@ -38,6 +38,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
hr_resize_y=hr_resize_y, hr_resize_y=hr_resize_y,
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name, hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler,
hr_prompt=hr_prompt, hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt, hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings, override_settings=override_settings,

View File

@ -322,10 +322,11 @@ def create_ui():
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") hr_sampler_name = gr.Dropdown(label='Sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
hr_scheduler = gr.Dropdown(label='Schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler")
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80): with gr.Column(scale=80):
@ -394,6 +395,7 @@ def create_ui():
hr_resize_y, hr_resize_y,
hr_checkpoint_name, hr_checkpoint_name,
hr_sampler_name, hr_sampler_name,
hr_scheduler,
hr_prompt, hr_prompt,
hr_negative_prompt, hr_negative_prompt,
override_settings, override_settings,
@ -456,8 +458,9 @@ def create_ui():
PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"), PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"), PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"), PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"), PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"),
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"),
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()),
PasteField(hr_prompt, "Hires prompt", api="hr_prompt"), PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"), PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),