diff --git a/modules/images.py b/modules/images.py index 26d5b7a95..8737ccff0 100644 --- a/modules/images.py +++ b/modules/images.py @@ -524,6 +524,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else: image.save(fullfn, quality=opts.jpeg_quality) + image.already_saved_as = fullfn + target_side_length = 4000 oversize = image.width > target_side_length or image.height > target_side_length if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024): diff --git a/modules/shared.py b/modules/shared.py index 8fb1387a9..af975f54e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -16,6 +16,9 @@ import modules.devices as devices from modules import localization, sd_vae, extensions, script_loading from modules.paths import models_path, script_path, sd_path + +demo = None + sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() @@ -292,6 +295,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), + + "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"), + "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), + })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { diff --git a/modules/ui.py b/modules/ui.py index c8b8fecd0..ea925c40e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -157,22 +157,6 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") -def save_pil_to_file(pil_image, dir=None): - use_metadata = False - metadata = PngImagePlugin.PngInfo() - for key, value in pil_image.info.items(): - if isinstance(key, str) and isinstance(value, str): - metadata.add_text(key, value) - use_metadata = True - - file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) - pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) - return file_obj - - -# override save to file function so that it also writes PNG info -gr.processing_utils.save_pil_to_file = save_pil_to_file - def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def f(*args, extra_outputs_array=extra_outputs, **kwargs): diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py new file mode 100644 index 000000000..9c6d3a9d3 --- /dev/null +++ b/modules/ui_tempdir.py @@ -0,0 +1,62 @@ +import os +import tempfile +from collections import namedtuple + +import gradio as gr + +from PIL import PngImagePlugin + +from modules import shared + + +Savedfile = namedtuple("Savedfile", ["name"]) + + +def save_pil_to_file(pil_image, dir=None): + already_saved_as = getattr(pil_image, 'already_saved_as', None) + if already_saved_as: + shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))} + file_obj = Savedfile(already_saved_as) + return file_obj + + if shared.opts.temp_dir != "": + dir = shared.opts.temp_dir + + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in pil_image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + + file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) + pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) + return file_obj + + +# override save to file function so that it also writes PNG info +gr.processing_utils.save_pil_to_file = save_pil_to_file + + +def on_tmpdir_changed(): + if shared.opts.temp_dir == "" or shared.demo is None: + return + + os.makedirs(shared.opts.temp_dir, exist_ok=True) + + shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)} + + +def cleanup_tmpdr(): + temp_dir = shared.opts.temp_dir + if temp_dir == "" or not os.path.isdir(temp_dir): + return + + for root, dirs, files in os.walk(temp_dir, topdown=False): + for name in files: + _, extension = os.path.splitext(name) + if extension != ".png": + continue + + filename = os.path.join(root, name) + os.remove(filename) diff --git a/webui.py b/webui.py index 23215d1e6..6b79dc558 100644 --- a/webui.py +++ b/webui.py @@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import shared, devices, sd_samplers, upscaler, extensions, localization +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -31,12 +31,14 @@ from modules import modelloader from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork + queue_lock = threading.Lock() if cmd_opts.server_name: server_name = cmd_opts.server_name else: server_name = "0.0.0.0" if cmd_opts.listen else None + def wrap_queued_call(func): def f(*args, **kwargs): with queue_lock: @@ -87,6 +89,7 @@ def initialize(): shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks())) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) + shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: @@ -149,9 +152,12 @@ def webui(): initialize() while 1: - demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + if shared.opts.clean_temp_dir_at_start: + ui_tempdir.cleanup_tmpdr() - app, local_url, share_url = demo.launch( + shared.demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + + app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, server_name=server_name, server_port=cmd_opts.port, @@ -178,9 +184,9 @@ def webui(): if launch_api: create_api(app) - modules.script_callbacks.app_started_callback(demo, app) + modules.script_callbacks.app_started_callback(shared.demo, app) - wait_on_server(demo) + wait_on_server(shared.demo) sd_samplers.set_samplers()