From 9aa9e980a9a2846755b72482e8b83b12cbc3ca0f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 24 Mar 2024 11:00:16 +0300 Subject: [PATCH] support scheduler selection in hires fix --- modules/infotext_utils.py | 3 ++ modules/processing.py | 6 +++ modules/processing_scripts/sampler.py | 38 +---------------- modules/sd_samplers.py | 60 ++++++++++++++++++++++++++- modules/sd_samplers_kdiffusion.py | 6 ++- modules/txt2img.py | 3 +- modules/ui.py | 11 +++-- 7 files changed, 83 insertions(+), 44 deletions(-) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 723cb1f82..1c91d076d 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -314,6 +314,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Hires sampler" not in res: res["Hires sampler"] = "Use same sampler" + if "Hires schedule type" not in res: + res["Hires schedule type"] = "Use same scheduler" + if "Hires checkpoint" not in res: res["Hires checkpoint"] = "Use same checkpoint" diff --git a/modules/processing.py b/modules/processing.py index ac5541418..2baca4f5f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1115,6 +1115,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): hr_resize_y: int = 0 hr_checkpoint_name: str = None hr_sampler_name: str = None + hr_scheduler: str = None hr_prompt: str = '' hr_negative_prompt: str = '' force_task_id: str = None @@ -1203,6 +1204,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name: self.extra_generation_params["Hires sampler"] = self.hr_sampler_name + self.extra_generation_params["Hires schedule type"] = None # to be set in sd_samplers_kdiffusion.py + + if self.hr_scheduler is None: + self.hr_scheduler = self.scheduler + self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") if self.enable_hr and self.latent_scale_mode is None: if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): diff --git a/modules/processing_scripts/sampler.py b/modules/processing_scripts/sampler.py index 83b952550..5d50a162c 100644 --- a/modules/processing_scripts/sampler.py +++ b/modules/processing_scripts/sampler.py @@ -1,44 +1,10 @@ import gradio as gr -import functools from modules import scripts, sd_samplers, sd_schedulers, shared from modules.infotext_utils import PasteField from modules.ui_components import FormRow, FormGroup -def get_sampler_from_infotext(d: dict): - return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0] - - -def get_scheduler_from_infotext(d: dict): - return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1] - - -@functools.cache -def get_sampler_and_scheduler(sampler_name, scheduler_name): - default_sampler = sd_samplers.samplers[0] - found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0]) - - name = sampler_name or default_sampler.name - - for scheduler in sd_schedulers.schedulers: - name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])] - - for name_option in name_options: - if name.endswith(" " + name_option): - found_scheduler = scheduler - name = name[0:-(len(name_option) + 1)] - break - - sampler = sd_samplers.all_samplers_map.get(name, default_sampler) - - # revert back to Automatic if it's the default scheduler for the selected sampler - if sampler.options.get('scheduler', None) == found_scheduler.name: - found_scheduler = sd_schedulers.schedulers[0] - - return sampler.name, found_scheduler.label - - class ScriptSampler(scripts.ScriptBuiltinUI): section = "sampler" @@ -67,8 +33,8 @@ class ScriptSampler(scripts.ScriptBuiltinUI): self.infotext_fields = [ PasteField(self.steps, "Steps", api="steps"), - PasteField(self.sampler_name, get_sampler_from_infotext, api="sampler_name"), - PasteField(self.scheduler, get_scheduler_from_infotext, api="scheduler"), + PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"), + PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"), ] return self.steps, self.sampler_name, self.scheduler diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 1f50297e6..6b7b84b6d 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,6 +1,8 @@ from __future__ import annotations -from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common +import functools + +from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers # imports for functions that previously were here and are used by other modules samples_to_image_grid = sd_samplers_common.samples_to_image_grid @@ -64,4 +66,60 @@ def visible_samplers(): return [x for x in samplers if x.name not in samplers_hidden] +def get_sampler_from_infotext(d: dict): + return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0] + + +def get_scheduler_from_infotext(d: dict): + return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1] + + +def get_hr_sampler_and_scheduler(d: dict): + hr_sampler = d.get("Hires sampler", "Use same sampler") + sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler + + hr_scheduler = d.get("Hires schedule type", "Use same scheduler") + scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler + + sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler) + + sampler = sampler if sampler != d.get("Sampler") else "Use same sampler" + scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler" + + return sampler, scheduler + + +def get_hr_sampler_from_infotext(d: dict): + return get_hr_sampler_and_scheduler(d)[0] + + +def get_hr_scheduler_from_infotext(d: dict): + return get_hr_sampler_and_scheduler(d)[1] + + +@functools.cache +def get_sampler_and_scheduler(sampler_name, scheduler_name): + default_sampler = samplers[0] + found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0]) + + name = sampler_name or default_sampler.name + + for scheduler in sd_schedulers.schedulers: + name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])] + + for name_option in name_options: + if name.endswith(" " + name_option): + found_scheduler = scheduler + name = name[0:-(len(name_option) + 1)] + break + + sampler = all_samplers_map.get(name, default_sampler) + + # revert back to Automatic if it's the default scheduler for the selected sampler + if sampler.options.get('scheduler', None) == found_scheduler.name: + found_scheduler = sd_schedulers.schedulers[0] + + return sampler.name, found_scheduler.label + + set_samplers() diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index d053e48b1..b45f85b07 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -79,7 +79,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler): steps += 1 if discard_next_to_last_sigma else 0 - scheduler_name = p.scheduler or 'Automatic' + scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic' if scheduler_name == 'Automatic': scheduler_name = self.config.options.get('scheduler', None) @@ -95,8 +95,10 @@ class KDiffusionSampler(sd_samplers_common.Sampler): else: sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max} - if scheduler.label != 'Automatic': + if scheduler.label != 'Automatic' and not p.is_hr_pass: p.extra_generation_params["Schedule type"] = scheduler.label + elif scheduler.label != p.extra_generation_params.get("Schedule type"): + p.extra_generation_params["Hires schedule type"] = scheduler.label if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min: sigmas_kwargs['sigma_min'] = opts.sigma_min diff --git a/modules/txt2img.py b/modules/txt2img.py index 53fa0099d..6f20253ae 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -11,7 +11,7 @@ from PIL import Image import gradio as gr -def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False): +def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False): override_settings = create_override_settings_dict(override_settings_texts) if force_enable_hr: @@ -38,6 +38,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne hr_resize_y=hr_resize_y, hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name, + hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler, hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, diff --git a/modules/ui.py b/modules/ui.py index c964d5e22..9b138e0aa 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -322,10 +322,11 @@ def create_ui(): with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: - hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") + hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") - hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") + hr_sampler_name = gr.Dropdown(label='Sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") + hr_scheduler = gr.Dropdown(label='Schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler") with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: with gr.Column(scale=80): @@ -394,6 +395,7 @@ def create_ui(): hr_resize_y, hr_checkpoint_name, hr_sampler_name, + hr_scheduler, hr_prompt, hr_negative_prompt, override_settings, @@ -456,8 +458,9 @@ def create_ui(): PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"), PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"), PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"), - PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"), - PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), + PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"), + PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"), + PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()), PasteField(hr_prompt, "Hires prompt", api="hr_prompt"), PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"), PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),