pass samplers from UI by name, make it possible to use a sampler from infotext even if it's hidden in the dropdown

This commit is contained in:
AUTOMATIC1111 2023-08-08 21:28:34 +03:00
parent ae1bde1aa1
commit 70c63c1208
4 changed files with 30 additions and 26 deletions

View File

@ -6,7 +6,7 @@ import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
import gradio as gr import gradio as gr
from modules import sd_samplers, images as imgutil from modules import images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state from modules.shared import opts, state
@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
process_images(p) process_images(p)
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5 is_batch = mode == 5
@ -172,7 +172,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
seed_resize_from_h=seed_resize_from_h, seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w, seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras, seed_enable_extras=seed_enable_extras,
sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name, sampler_name=sampler_name,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,
steps=steps, steps=steps,

View File

@ -12,6 +12,7 @@ all_samplers_map = {x.name: x for x in all_samplers}
samplers = [] samplers = []
samplers_for_img2img = [] samplers_for_img2img = []
samplers_map = {} samplers_map = {}
samplers_hidden = {}
def find_sampler_config(name): def find_sampler_config(name):
@ -38,11 +39,11 @@ def create_sampler(name, model):
def set_samplers(): def set_samplers():
global samplers, samplers_for_img2img global samplers, samplers_for_img2img, samplers_hidden
hidden = set(shared.opts.hide_samplers) samplers_hidden = set(shared.opts.hide_samplers)
samplers = [x for x in all_samplers if x.name not in hidden] samplers = all_samplers
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden] samplers_for_img2img = all_samplers
samplers_map.clear() samplers_map.clear()
for sampler in all_samplers: for sampler in all_samplers:
@ -51,4 +52,8 @@ def set_samplers():
samplers_map[alias.lower()] = sampler.name samplers_map[alias.lower()] = sampler.name
def visible_sampler_names():
return [x.name for x in samplers if x.name not in samplers_hidden]
set_samplers() set_samplers()

View File

@ -1,7 +1,7 @@
from contextlib import closing from contextlib import closing
import modules.scripts import modules.scripts
from modules import sd_samplers, processing from modules import processing
from modules.generation_parameters_copypaste import create_override_settings_dict from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
import modules.shared as shared import modules.shared as shared
@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import gradio as gr import gradio as gr
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
@ -25,7 +25,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
seed_resize_from_h=seed_resize_from_h, seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w, seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras, seed_enable_extras=seed_enable_extras,
sampler_name=sd_samplers.samplers[sampler_index].name, sampler_name=sampler_name,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,
steps=steps, steps=steps,
@ -42,7 +42,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_resize_x=hr_resize_x, hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y, hr_resize_y=hr_resize_y,
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, hr_sampler_name=hr_sampler_name,
hr_prompt=hr_prompt, hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt, hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings, override_settings=override_settings,

View File

@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401 from modules import gradio_extensons # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path from modules.paths import script_path
from modules.ui_common import create_refresh_button from modules.ui_common import create_refresh_button
@ -29,7 +29,6 @@ import modules.shared as shared
import modules.images import modules.images
from modules import prompt_parser from modules import prompt_parser
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.generation_parameters_copypaste import image_from_url_text from modules.generation_parameters_copypaste import image_from_url_text
create_setting_component = ui_settings.create_setting_component create_setting_component = ui_settings.create_setting_component
@ -360,14 +359,14 @@ def create_output_panel(tabname, outdir):
def create_sampler_and_steps_selection(choices, tabname): def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown: if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"): with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
else: else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"): with FormGroup(elem_id=f"sampler_selection_{tabname}"):
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
return steps, sampler_index return steps, sampler_name
def ordered_ui_categories(): def ordered_ui_categories():
@ -414,7 +413,7 @@ def create_ui():
for category in ordered_ui_categories(): for category in ordered_ui_categories():
if category == "sampler": if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
elif category == "dimensions": elif category == "dimensions":
with FormRow(): with FormRow():
@ -460,7 +459,7 @@ def create_ui():
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index") hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80): with gr.Column(scale=80):
@ -520,7 +519,7 @@ def create_ui():
toprow.negative_prompt, toprow.negative_prompt,
toprow.ui_styles.dropdown, toprow.ui_styles.dropdown,
steps, steps,
sampler_index, sampler_name,
restore_faces, restore_faces,
tiling, tiling,
batch_count, batch_count,
@ -538,7 +537,7 @@ def create_ui():
hr_resize_x, hr_resize_x,
hr_resize_y, hr_resize_y,
hr_checkpoint_name, hr_checkpoint_name,
hr_sampler_index, hr_sampler_name,
hr_prompt, hr_prompt,
hr_negative_prompt, hr_negative_prompt,
override_settings, override_settings,
@ -583,7 +582,7 @@ def create_ui():
(toprow.prompt, "Prompt"), (toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"), (toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"), (steps, "Steps"),
(sampler_index, "Sampler"), (sampler_name, "Sampler"),
(restore_faces, "Face restoration"), (restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"), (cfg_scale, "CFG scale"),
(seed, "Seed"), (seed, "Seed"),
@ -605,7 +604,7 @@ def create_ui():
(hr_resize_x, "Hires resize-1"), (hr_resize_x, "Hires resize-1"),
(hr_resize_y, "Hires resize-2"), (hr_resize_y, "Hires resize-2"),
(hr_checkpoint_name, "Hires checkpoint"), (hr_checkpoint_name, "Hires checkpoint"),
(hr_sampler_index, "Hires sampler"), (hr_sampler_name, "Hires sampler"),
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
(hr_prompt, "Hires prompt"), (hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"), (hr_negative_prompt, "Hires negative prompt"),
@ -621,7 +620,7 @@ def create_ui():
toprow.prompt, toprow.prompt,
toprow.negative_prompt, toprow.negative_prompt,
steps, steps,
sampler_index, sampler_name,
cfg_scale, cfg_scale,
seed, seed,
width, width,
@ -744,7 +743,7 @@ def create_ui():
for category in ordered_ui_categories(): for category in ordered_ui_categories():
if category == "sampler": if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
elif category == "dimensions": elif category == "dimensions":
with FormRow(): with FormRow():
@ -876,7 +875,7 @@ def create_ui():
init_img_inpaint, init_img_inpaint,
init_mask_inpaint, init_mask_inpaint,
steps, steps,
sampler_index, sampler_name,
mask_blur, mask_blur,
mask_alpha, mask_alpha,
inpainting_fill, inpainting_fill,
@ -972,7 +971,7 @@ def create_ui():
(toprow.prompt, "Prompt"), (toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"), (toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"), (steps, "Steps"),
(sampler_index, "Sampler"), (sampler_name, "Sampler"),
(restore_faces, "Face restoration"), (restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"), (cfg_scale, "CFG scale"),
(image_cfg_scale, "Image CFG scale"), (image_cfg_scale, "Image CFG scale"),