From df02498d03e4296b7d7581aff69571a49be1d27a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 31 May 2023 22:40:09 +0300 Subject: [PATCH] add an option to show selected setting in main txt2img/img2img UI split some code from ui.py into ui_settings.py ui_gradio_edxtensions.py add before_process callback for scripts add ability for alwayson scripts to specify section and let user reorder those sections --- .../scripts/extra_options_section.py | 48 +++ modules/processing.py | 6 +- modules/scripts.py | 144 ++++--- modules/shared_items.py | 10 + modules/ui.py | 351 ++---------------- modules/ui_common.py | 23 ++ modules/ui_gradio_extensions.py | 69 ++++ modules/ui_settings.py | 263 +++++++++++++ 8 files changed, 526 insertions(+), 388 deletions(-) create mode 100644 extensions-builtin/extra-options-section/scripts/extra_options_section.py create mode 100644 modules/ui_gradio_extensions.py create mode 100644 modules/ui_settings.py diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py new file mode 100644 index 000000000..17f841844 --- /dev/null +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -0,0 +1,48 @@ +import gradio as gr +from modules import scripts, shared, ui_components, ui_settings +from modules.ui_components import FormColumn + + +class ExtraOptionsSection(scripts.Script): + section = "extra_options" + + def __init__(self): + self.comps = None + self.setting_names = None + + def title(self): + return "Extra options" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + self.comps = [] + self.setting_names = [] + + with gr.Blocks() as interface: + with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and len(shared.opts.extra_options) > 0 else gr.Group(), gr.Row(): + for setting_name in shared.opts.extra_options: + with FormColumn(): + comp = ui_settings.create_setting_component(setting_name) + + self.comps.append(comp) + self.setting_names.append(setting_name) + + def get_settings_values(): + return [ui_settings.get_value_for_setting(key) for key in self.setting_names] + + interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False) + + return self.comps + + def before_process(self, p, *args): + for name, value in zip(self.setting_names, args): + if name not in p.override_settings: + p.override_settings[name] = value + + +shared.options_templates.update(shared.options_section(('ui', "User interface"), { + "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(), + "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion") +})) diff --git a/modules/processing.py b/modules/processing.py index f628d88bd..baa9b2782 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -588,11 +588,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter def process_images(p: StableDiffusionProcessing) -> Processed: + if p.scripts is not None: + p.scripts.before_process(p) + stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} try: # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint - if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None: + override_checkpoint = p.override_settings.get('sd_model_checkpoint') + if override_checkpoint is not None and sd_models.checkpoint_alisases.get(override_checkpoint) is None: p.override_settings.pop('sd_model_checkpoint', None) sd_models.reload_model_weights() diff --git a/modules/scripts.py b/modules/scripts.py index 0970f38e0..b901862dc 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -19,6 +19,9 @@ class Script: name = None """script's internal name derived from title""" + section = None + """name of UI section that the script's controls will be placed into""" + filename = None args_from = None args_to = None @@ -81,6 +84,15 @@ class Script: pass + def before_process(self, p, *args): + """ + This function is called very early before processing begins for AlwaysVisible scripts. + You can modify the processing object (p) here, inject hooks, etc. + args contains all values returned by components from ui() + """ + + pass + def process(self, p, *args): """ This function is called before processing begins for AlwaysVisible scripts. @@ -293,6 +305,7 @@ class ScriptRunner: self.titles = [] self.infotext_fields = [] self.paste_field_names = [] + self.inputs = [None] def initialize_scripts(self, is_img2img): from modules import scripts_auto_postprocessing @@ -320,69 +333,73 @@ class ScriptRunner: self.scripts.append(script) self.selectable_scripts.append(script) - def setup_ui(self): + def create_script_ui(self, script): import modules.api.models as api_models + script.args_from = len(self.inputs) + script.args_to = len(self.inputs) + + controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) + + if controls is None: + return + + script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() + api_args = [] + + for control in controls: + control.custom_script_source = os.path.basename(script.filename) + + arg_info = api_models.ScriptArg(label=control.label or "") + + for field in ("value", "minimum", "maximum", "step", "choices"): + v = getattr(control, field, None) + if v is not None: + setattr(arg_info, field, v) + + api_args.append(arg_info) + + script.api_info = api_models.ScriptInfo( + name=script.name, + is_img2img=script.is_img2img, + is_alwayson=script.alwayson, + args=api_args, + ) + + if script.infotext_fields is not None: + self.infotext_fields += script.infotext_fields + + if script.paste_field_names is not None: + self.paste_field_names += script.paste_field_names + + self.inputs += controls + script.args_to = len(self.inputs) + + def setup_ui_for_section(self, section, scriptlist=None): + if scriptlist is None: + scriptlist = self.alwayson_scripts + + for script in scriptlist: + if script.alwayson and script.section != section: + continue + + with gr.Group(visible=script.alwayson) as group: + self.create_script_ui(script) + + script.group = group + + def prepare_ui(self): + self.inputs = [None] + + def setup_ui(self): self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] - inputs = [None] - inputs_alwayson = [True] - - def create_script_ui(script, inputs, inputs_alwayson): - script.args_from = len(inputs) - script.args_to = len(inputs) - - controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) - - if controls is None: - return - - script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() - api_args = [] - - for control in controls: - control.custom_script_source = os.path.basename(script.filename) - - arg_info = api_models.ScriptArg(label=control.label or "") - - for field in ("value", "minimum", "maximum", "step", "choices"): - v = getattr(control, field, None) - if v is not None: - setattr(arg_info, field, v) - - api_args.append(arg_info) - - script.api_info = api_models.ScriptInfo( - name=script.name, - is_img2img=script.is_img2img, - is_alwayson=script.alwayson, - args=api_args, - ) - - if script.infotext_fields is not None: - self.infotext_fields += script.infotext_fields - - if script.paste_field_names is not None: - self.paste_field_names += script.paste_field_names - - inputs += controls - inputs_alwayson += [script.alwayson for _ in controls] - script.args_to = len(inputs) - - for script in self.alwayson_scripts: - with gr.Group() as group: - create_script_ui(script, inputs, inputs_alwayson) - - script.group = group + self.setup_ui_for_section(None) dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") - inputs[0] = dropdown + self.inputs[0] = dropdown - for script in self.selectable_scripts: - with gr.Group(visible=False) as group: - create_script_ui(script, inputs, inputs_alwayson) - - script.group = group + self.setup_ui_for_section(None, self.selectable_scripts) def select_script(script_index): selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None @@ -407,6 +424,7 @@ class ScriptRunner: ) self.script_load_ctr = 0 + def onload_script_visibility(params): title = params.get('Script', None) if title: @@ -417,10 +435,10 @@ class ScriptRunner: else: return gr.update(visible=False) - self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) ) - self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] ) + self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None')))) + self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts]) - return inputs + return self.inputs def run(self, p, *args): script_index = args[0] @@ -440,6 +458,14 @@ class ScriptRunner: return processed + def before_process(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.before_process(p, *script_args) + except Exception: + errors.report(f"Error running before_process: {script.filename}", exc_info=True) + def process(self, p): for script in self.alwayson_scripts: try: diff --git a/modules/shared_items.py b/modules/shared_items.py index 27bceb181..89792e88a 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -55,5 +55,15 @@ ui_reorder_categories_builtin_items = [ def ui_reorder_categories(): + from modules import scripts + yield from ui_reorder_categories_builtin_items + + sections = {} + for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts: + if isinstance(script.section, str): + sections[script.section] = 1 + + yield from sections + yield "scripts" diff --git a/modules/ui.py b/modules/ui.py index 35563669d..4e0cf7763 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -6,15 +6,17 @@ from functools import reduce import warnings import gradio as gr -import gradio.routes import gradio.utils import numpy as np from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path, data_path +from modules.paths import script_path +from modules.ui_common import create_refresh_button +from modules.ui_gradio_extensions import reload_javascript + from modules.shared import opts, cmd_opts @@ -34,6 +36,8 @@ import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text import modules.extras +create_setting_component = ui_settings.create_setting_component + warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI @@ -366,25 +370,6 @@ def apply_setting(key, value): return getattr(opts, key) -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - def create_output_panel(tabname, outdir): return ui_common.create_output_panel(tabname, outdir) @@ -409,16 +394,6 @@ def ordered_ui_categories(): yield category -def get_value_for_setting(key): - value = getattr(opts, key) - - info = opts.data_labels[key] - args = info.component_args() if callable(info.component_args) else info.component_args or {} - args = {k: v for k, v in args.items() if k not in {'precision'}} - - return gr.update(value=value, **args) - - def create_override_settings_dropdown(tabname, row): dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) @@ -454,6 +429,8 @@ def create_ui(): with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): + modules.scripts.scripts_txt2img.prepare_ui() + for category in ordered_ui_categories(): if category == "sampler": steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") @@ -522,6 +499,9 @@ def create_ui(): with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + else: + modules.scripts.scripts_txt2img.setup_ui_for_section(category) + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] for component in hr_resolution_preview_inputs: @@ -778,6 +758,8 @@ def create_ui(): with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + modules.scripts.scripts_img2img.prepare_ui() + for category in ordered_ui_categories(): if category == "sampler": steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") @@ -887,6 +869,8 @@ def create_ui(): inputs=[], outputs=[inpaint_controls, mask_alpha], ) + else: + modules.scripts.scripts_img2img.setup_ui_for_section(category) img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) @@ -1460,195 +1444,10 @@ def create_ui(): outputs=[], ) - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {t} for key {key}') - - elem_id = f"setting_{key}" - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") - else: - with FormRow(): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) - components = [] - component_dict = {} - shared.settings_components = component_dict - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' - - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return get_value_for_setting(key), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - with gr.Row(): - with gr.Column(scale=6): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - with gr.Column(): - restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") - - result = gr.HTML(elem_id="settings_result") - - quicksettings_names = opts.quicksettings_list - quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} - - quicksettings_list = [] - - previous_section = None - current_tab = None - current_row = None - with gr.Tabs(elem_id="settings"): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - elem_id, text = item.section - - if current_tab is not None: - current_row.__exit__() - current_tab.__exit__() - - gr.Group() - current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) - current_tab.__enter__() - current_row = gr.Column(variant='compact') - current_row.__enter__() - - previous_section = item.section - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - - if current_tab is not None: - current_row.__exit__() - current_tab.__exit__() - - with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"): - loadsave.create_ui() - - with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - with gr.Row(): - unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") - reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") - - with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): - gr.HTML(shared.html("licenses.html"), elem_id="licenses") - - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - - - def unload_sd_weights(): - modules.sd_models.unload_model_weights() - - def reload_sd_weights(): - modules.sd_models.reload_model_weights() - - unload_sd_model.click( - fn=unload_sd_weights, - inputs=[], - outputs=[] - ) - - reload_sd_model.click( - fn=reload_sd_weights, - inputs=[], - outputs=[] - ) - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - restart_gradio.click( - fn=shared.state.request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) + settings = ui_settings.UiSettings() + settings.create_ui(loadsave, dummy_component) interfaces = [ (txt2img_interface, "txt2img", "txt2img"), @@ -1660,7 +1459,7 @@ def create_ui(): ] interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] + interfaces += [(settings.interface, "Settings", "settings")] extensions_interface = ui_extensions.create_ui() interfaces += [(extensions_interface, "Extensions", "extensions")] @@ -1670,10 +1469,7 @@ def create_ui(): shared.tab_names.append(label) with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings", variant="compact"): - for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component + settings.add_quicksettings() parameters_copypaste.connect_paste_params_buttons() @@ -1704,49 +1500,12 @@ def create_ui(): footer = footer.format(versions=versions_html()) gr.HTML(footer, elem_id="footer") - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for _i, k, _item in quicksettings_list: - component = component_dict[k] - info = opts.data_labels[k] - - change_handler = component.release if hasattr(component, 'release') else component.change - change_handler( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - show_progress=info.refresh is not None, - ) + settings.add_functionality(demo) update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") - text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) + settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) - button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) - button_set_checkpoint.click( - fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'), - _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", - inputs=[component_dict['sd_model_checkpoint'], dummy_component], - outputs=[component_dict['sd_model_checkpoint'], text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [get_value_for_setting(key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - queue=False, - ) - def modelmerger(*args): try: results = modules.extras.run_modelmerger(*args) @@ -1779,7 +1538,7 @@ def create_ui(): primary_model_name, secondary_model_name, tertiary_model_name, - component_dict['sd_model_checkpoint'], + settings.component_dict['sd_model_checkpoint'], modelmerger_result, ] ) @@ -1793,70 +1552,6 @@ def create_ui(): return demo -def webpath(fn): - if fn.startswith(script_path): - web_path = os.path.relpath(fn, script_path).replace('\\', '/') - else: - web_path = os.path.abspath(fn) - - return f'file={web_path}?{os.path.getmtime(fn)}' - - -def javascript_html(): - # Ensure localization is in `window` before scripts - head = f'\n' - - script_js = os.path.join(script_path, "script.js") - head += f'\n' - - for script in modules.scripts.list_scripts("javascript", ".js"): - head += f'\n' - - for script in modules.scripts.list_scripts("javascript", ".mjs"): - head += f'\n' - - if cmd_opts.theme: - head += f'\n' - - return head - - -def css_html(): - head = "" - - def stylesheet(fn): - return f'' - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - head += stylesheet(cssfile) - - if os.path.exists(os.path.join(data_path, "user.css")): - head += stylesheet(os.path.join(data_path, "user.css")) - - return head - - -def reload_javascript(): - js = javascript_html() - css = css_html() - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace(b'', f'{js}'.encode("utf8")) - res.body = res.body.replace(b'', f'{css}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse - - def versions_html(): import torch import launch diff --git a/modules/ui_common.py b/modules/ui_common.py index 5a9204a4d..57c2d0ade 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -10,8 +10,11 @@ import subprocess as sp from modules import call_queue, shared from modules.generation_parameters_copypaste import image_from_url_text import modules.images +from modules.ui_components import ToolButton + folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 def update_generation_info(generation_info, html_info, img_index): @@ -216,3 +219,23 @@ Requested path was: {f} )) return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py new file mode 100644 index 000000000..b824b1137 --- /dev/null +++ b/modules/ui_gradio_extensions.py @@ -0,0 +1,69 @@ +import os +import gradio as gr + +from modules import localization, shared, scripts +from modules.paths import script_path, data_path + + +def webpath(fn): + if fn.startswith(script_path): + web_path = os.path.relpath(fn, script_path).replace('\\', '/') + else: + web_path = os.path.abspath(fn) + + return f'file={web_path}?{os.path.getmtime(fn)}' + + +def javascript_html(): + # Ensure localization is in `window` before scripts + head = f'\n' + + script_js = os.path.join(script_path, "script.js") + head += f'\n' + + for script in scripts.list_scripts("javascript", ".js"): + head += f'\n' + + for script in scripts.list_scripts("javascript", ".mjs"): + head += f'\n' + + if shared.cmd_opts.theme: + head += f'\n' + + return head + + +def css_html(): + head = "" + + def stylesheet(fn): + return f'' + + for cssfile in scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + head += stylesheet(cssfile) + + if os.path.exists(os.path.join(data_path, "user.css")): + head += stylesheet(os.path.join(data_path, "user.css")) + + return head + + +def reload_javascript(): + js = javascript_html() + css = css_html() + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace(b'', f'{js}'.encode("utf8")) + res.body = res.body.replace(b'', f'{css}'.encode("utf8")) + res.init_headers() + return res + + gr.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse diff --git a/modules/ui_settings.py b/modules/ui_settings.py new file mode 100644 index 000000000..7874298e8 --- /dev/null +++ b/modules/ui_settings.py @@ -0,0 +1,263 @@ +import gradio as gr + +from modules import ui_common, shared, script_callbacks, scripts, sd_models +from modules.call_queue import wrap_gradio_call +from modules.shared import opts +from modules.ui_components import FormRow +from modules.ui_gradio_extensions import reload_javascript + + +def get_value_for_setting(key): + value = getattr(opts, key) + + info = opts.data_labels[key] + args = info.component_args() if callable(info.component_args) else info.component_args or {} + args = {k: v for k, v in args.items() if k not in {'precision'}} + + return gr.update(value=value, **args) + + +def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {t} for key {key}') + + elem_id = f"setting_{key}" + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") + else: + with FormRow(): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + +class UiSettings: + submit = None + result = None + interface = None + components = None + component_dict = None + dummy_component = None + quicksettings_list = None + quicksettings_names = None + text_settings = None + + def run_settings(self, *args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, self.components): + assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, self.components): + if comp == self.dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' + + def run_settings_single(self, value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return get_value_for_setting(key), opts.dumpjson() + + def create_ui(self, loadsave, dummy_component): + self.components = [] + self.component_dict = {} + self.dummy_component = dummy_component + + shared.settings_components = self.component_dict + + script_callbacks.ui_settings_callback() + opts.reorder() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(scale=6): + self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + self.result = gr.HTML(elem_id="settings_result") + + self.quicksettings_names = opts.quicksettings_list + self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'} + + self.quicksettings_list = [] + + previous_section = None + current_tab = None + current_row = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + gr.Group() + current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) + current_tab.__enter__() + current_row = gr.Column(variant='compact') + current_row.__enter__() + + previous_section = item.section + + if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings: + self.quicksettings_list.append((i, k, item)) + self.components.append(dummy_component) + elif section_must_be_skipped: + self.components.append(dummy_component) + else: + component = create_setting_component(k) + self.component_dict[k] = component + self.components.append(component) + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"): + loadsave.create_ui() + + with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + with gr.Row(): + unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") + reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") + + with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): + gr.HTML(shared.html("licenses.html"), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + + unload_sd_model.click( + fn=sd_models.unload_model_weights, + inputs=[], + outputs=[] + ) + + reload_sd_model.click( + fn=sd_models.reload_model_weights, + inputs=[], + outputs=[] + ) + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + restart_gradio.click( + fn=shared.state.request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + self.interface = settings_interface + + def add_quicksettings(self): + with gr.Row(elem_id="quicksettings", variant="compact"): + for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])): + component = create_setting_component(k, is_quicksettings=True) + self.component_dict[k] = component + + def add_functionality(self, demo): + self.submit.click( + fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]), + inputs=self.components, + outputs=[self.text_settings, self.result], + ) + + for _i, k, _item in self.quicksettings_list: + component = self.component_dict[k] + info = opts.data_labels[k] + + change_handler = component.release if hasattr(component, 'release') else component.change + change_handler( + fn=lambda value, k=k: self.run_settings_single(value, key=k), + inputs=[component], + outputs=[component, self.text_settings], + show_progress=info.refresh is not None, + ) + + button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) + button_set_checkpoint.click( + fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'), + _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", + inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component], + outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict] + + def get_settings_values(): + return [get_value_for_setting(key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[self.component_dict[k] for k in component_keys], + queue=False, + )