do not break progressbar on non-job actions

This commit is contained in:
Andray 2024-07-12 16:08:36 +04:00
parent 1da4907927
commit 589dda3cf2
5 changed files with 30 additions and 21 deletions

View File

@ -47,6 +47,22 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
@wraps(func)
def f(*args, **kwargs):
try:
res = func(*args, **kwargs)
finally:
shared.state.skipped = False
shared.state.interrupted = False
shared.state.stopping_generation = False
shared.state.job_count = 0
shared.state.job = ""
return res
return wrap_gradio_call_no_job(f, extra_outputs, add_stats)
def wrap_gradio_call_no_job(func, extra_outputs=None, add_stats=False):
@wraps(func) @wraps(func)
def f(*args, extra_outputs_array=extra_outputs, **kwargs): def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
@ -66,9 +82,6 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)" arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
errors.report(f"{message}\n{arg_str}", exc_info=True) errors.report(f"{message}\n{arg_str}", exc_info=True)
shared.state.job = ""
shared.state.job_count = 0
if extra_outputs_array is None: if extra_outputs_array is None:
extra_outputs_array = [None, ''] extra_outputs_array = [None, '']
@ -77,11 +90,6 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
devices.torch_gc() devices.torch_gc()
shared.state.skipped = False
shared.state.interrupted = False
shared.state.stopping_generation = False
shared.state.job_count = 0
if not add_stats: if not add_stats:
return tuple(res) return tuple(res)
@ -123,3 +131,4 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
return tuple(res) return tuple(res)
return f return f

View File

@ -10,7 +10,7 @@ import gradio as gr
import gradio.utils import gradio.utils
import numpy as np import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401 from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call, wrap_gradio_call_no_job # noqa: F401
from modules import gradio_extensons, sd_schedulers # noqa: F401 from modules import gradio_extensons, sd_schedulers # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
@ -889,7 +889,7 @@ def create_ui():
)) ))
image.change( image.change(
fn=wrap_gradio_call(modules.extras.run_pnginfo), fn=wrap_gradio_call_no_job(modules.extras.run_pnginfo),
inputs=[image], inputs=[image],
outputs=[html, generation_info, html2], outputs=[html, generation_info, html2],
) )

View File

@ -228,7 +228,7 @@ def create_output_panel(tabname, outdir, toprow=None):
) )
save.click( save.click(
fn=call_queue.wrap_gradio_call(save_files), fn=call_queue.wrap_gradio_call_no_job(save_files),
_js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
inputs=[ inputs=[
res.generation_info, res.generation_info,
@ -244,7 +244,7 @@ def create_output_panel(tabname, outdir, toprow=None):
) )
save_zip.click( save_zip.click(
fn=call_queue.wrap_gradio_call(save_files), fn=call_queue.wrap_gradio_call_no_job(save_files),
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
inputs=[ inputs=[
res.generation_info, res.generation_info,

View File

@ -624,37 +624,37 @@ def create_ui():
) )
install_extension_button.click( install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), fn=modules.ui.wrap_gradio_call_no_job(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
inputs=[extension_to_install, selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], inputs=[extension_to_install, selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],
outputs=[available_extensions_table, extensions_table, install_result], outputs=[available_extensions_table, extensions_table, install_result],
) )
search_extensions_text.change( search_extensions_text.change(
fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]), fn=modules.ui.wrap_gradio_call_no_job(search_extensions, extra_outputs=[gr.update()]),
inputs=[search_extensions_text, selected_tags, showing_type, filtering_type, sort_column], inputs=[search_extensions_text, selected_tags, showing_type, filtering_type, sort_column],
outputs=[available_extensions_table, install_result], outputs=[available_extensions_table, install_result],
) )
selected_tags.change( selected_tags.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],
outputs=[available_extensions_table, install_result] outputs=[available_extensions_table, install_result]
) )
showing_type.change( showing_type.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],
outputs=[available_extensions_table, install_result] outputs=[available_extensions_table, install_result]
) )
filtering_type.change( filtering_type.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],
outputs=[available_extensions_table, install_result] outputs=[available_extensions_table, install_result]
) )
sort_column.change( sort_column.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text], inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],
outputs=[available_extensions_table, install_result] outputs=[available_extensions_table, install_result]
) )
@ -667,7 +667,7 @@ def create_ui():
install_result = gr.HTML(elem_id="extension_install_result") install_result = gr.HTML(elem_id="extension_install_result")
install_button.click( install_button.click(
fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]), fn=modules.ui.wrap_gradio_call_no_job(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
inputs=[install_dirname, install_url, install_branch], inputs=[install_dirname, install_url, install_branch],
outputs=[install_url, extensions_table, install_result], outputs=[install_url, extensions_table, install_result],
) )

View File

@ -1,7 +1,7 @@
import gradio as gr import gradio as gr
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items
from modules.call_queue import wrap_gradio_call from modules.call_queue import wrap_gradio_call_no_job
from modules.options import options_section from modules.options import options_section
from modules.shared import opts from modules.shared import opts
from modules.ui_components import FormRow from modules.ui_components import FormRow
@ -295,7 +295,7 @@ class UiSettings:
def add_functionality(self, demo): def add_functionality(self, demo):
self.submit.click( self.submit.click(
fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]), fn=wrap_gradio_call_no_job(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
inputs=self.components, inputs=self.components,
outputs=[self.text_settings, self.result], outputs=[self.text_settings, self.result],
) )