From 589dda3cf2954beaeef65928c063e2bb5c680209 Mon Sep 17 00:00:00 2001 From: Andray Date: Fri, 12 Jul 2024 16:08:36 +0400 Subject: [PATCH] do not break progressbar on non-job actions --- modules/call_queue.py | 25 +++++++++++++++++-------- modules/ui.py | 4 ++-- modules/ui_common.py | 4 ++-- modules/ui_extensions.py | 14 +++++++------- modules/ui_settings.py | 4 ++-- 5 files changed, 30 insertions(+), 21 deletions(-) diff --git a/modules/call_queue.py b/modules/call_queue.py index d22c23b31..555c35312 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -47,6 +47,22 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def wrap_gradio_call(func, extra_outputs=None, add_stats=False): + @wraps(func) + def f(*args, **kwargs): + try: + res = func(*args, **kwargs) + finally: + shared.state.skipped = False + shared.state.interrupted = False + shared.state.stopping_generation = False + shared.state.job_count = 0 + shared.state.job = "" + return res + + return wrap_gradio_call_no_job(f, extra_outputs, add_stats) + + +def wrap_gradio_call_no_job(func, extra_outputs=None, add_stats=False): @wraps(func) def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats @@ -66,9 +82,6 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)" errors.report(f"{message}\n{arg_str}", exc_info=True) - shared.state.job = "" - shared.state.job_count = 0 - if extra_outputs_array is None: extra_outputs_array = [None, ''] @@ -77,11 +90,6 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): devices.torch_gc() - shared.state.skipped = False - shared.state.interrupted = False - shared.state.stopping_generation = False - shared.state.job_count = 0 - if not add_stats: return tuple(res) @@ -123,3 +131,4 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): return tuple(res) return f + diff --git a/modules/ui.py b/modules/ui.py index 5af34ecb0..8edce620f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -10,7 +10,7 @@ import gradio as gr import gradio.utils import numpy as np from PIL import Image, PngImagePlugin # noqa: F401 -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call, wrap_gradio_call_no_job # noqa: F401 from modules import gradio_extensons, sd_schedulers # noqa: F401 from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils @@ -889,7 +889,7 @@ def create_ui(): )) image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), + fn=wrap_gradio_call_no_job(modules.extras.run_pnginfo), inputs=[image], outputs=[html, generation_info, html2], ) diff --git a/modules/ui_common.py b/modules/ui_common.py index 48992a3c1..fb3967701 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -228,7 +228,7 @@ def create_output_panel(tabname, outdir, toprow=None): ) save.click( - fn=call_queue.wrap_gradio_call(save_files), + fn=call_queue.wrap_gradio_call_no_job(save_files), _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", inputs=[ res.generation_info, @@ -244,7 +244,7 @@ def create_output_panel(tabname, outdir, toprow=None): ) save_zip.click( - fn=call_queue.wrap_gradio_call(save_files), + fn=call_queue.wrap_gradio_call_no_job(save_files), _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", inputs=[ res.generation_info, diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 6b6403f23..23aff7096 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -624,37 +624,37 @@ def create_ui(): ) install_extension_button.click( - fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), + fn=modules.ui.wrap_gradio_call_no_job(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), inputs=[extension_to_install, selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], outputs=[available_extensions_table, extensions_table, install_result], ) search_extensions_text.change( - fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]), + fn=modules.ui.wrap_gradio_call_no_job(search_extensions, extra_outputs=[gr.update()]), inputs=[search_extensions_text, selected_tags, showing_type, filtering_type, sort_column], outputs=[available_extensions_table, install_result], ) selected_tags.change( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], outputs=[available_extensions_table, install_result] ) showing_type.change( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], outputs=[available_extensions_table, install_result] ) filtering_type.change( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], outputs=[available_extensions_table, install_result] ) sort_column.change( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], outputs=[available_extensions_table, install_result] ) @@ -667,7 +667,7 @@ def create_ui(): install_result = gr.HTML(elem_id="extension_install_result") install_button.click( - fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]), + fn=modules.ui.wrap_gradio_call_no_job(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]), inputs=[install_dirname, install_url, install_branch], outputs=[install_url, extensions_table, install_result], ) diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 087b91f3b..e53ad50f8 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -1,7 +1,7 @@ import gradio as gr from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items -from modules.call_queue import wrap_gradio_call +from modules.call_queue import wrap_gradio_call_no_job from modules.options import options_section from modules.shared import opts from modules.ui_components import FormRow @@ -295,7 +295,7 @@ class UiSettings: def add_functionality(self, demo): self.submit.click( - fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]), + fn=wrap_gradio_call_no_job(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]), inputs=self.components, outputs=[self.text_settings, self.result], )