From 065364445d4ea1ddec44c3f87d1b6b8acda592a6 Mon Sep 17 00:00:00 2001 From: EternalNooblet Date: Fri, 7 Oct 2022 15:25:01 -0400 Subject: [PATCH 01/72] added a flag to run as root if needed --- webui.sh | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/webui.sh b/webui.sh index 05ca497d2..41649b9a4 100755 --- a/webui.sh +++ b/webui.sh @@ -3,6 +3,7 @@ # Please do not make any changes to this file, # # change the variables in webui-user.sh instead # ################################################# + # Read variables from webui-user.sh # shellcheck source=/dev/null if [[ -f webui-user.sh ]] @@ -46,6 +47,17 @@ then LAUNCH_SCRIPT="launch.py" fi +# this script cannot be run as root by default +can_run_as_root=0 + +# read any command line flags to the webui.sh script +while getopts "f" flag +do + case ${flag} in + f) can_run_as_root=1;; + esac +done + # Disable sentry logging export ERROR_REPORTING=FALSE @@ -61,7 +73,7 @@ printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m" printf "\n%s\n" "${delimiter}" # Do not run as root -if [[ $(id -u) -eq 0 ]] +if [[ $(id -u) -eq 0 && can_run_as_root -eq 0 ]] then printf "\n%s\n" "${delimiter}" printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m" From a258fd60dbe2d68325339405a2aa72816d06d2fd Mon Sep 17 00:00:00 2001 From: Keavon Chambers Date: Mon, 7 Nov 2022 00:13:58 -0800 Subject: [PATCH 02/72] Add CORS-allow policy launch argument using regex --- modules/shared.py | 7 ++++--- webui.py | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index e8bacd3c9..55de286d8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -81,12 +81,13 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) -parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") -parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") +parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) -parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None) +parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None) +parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None) parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) diff --git a/webui.py b/webui.py index f4f1d74d1..066d94f71 100644 --- a/webui.py +++ b/webui.py @@ -107,8 +107,12 @@ def initialize(): def setup_cors(app): - if cmd_opts.cors_allow_origins: + if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) + elif cmd_opts.cors_allow_origins: app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) + elif cmd_opts.cors_allow_origins_regex: + app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) def create_api(app): From c556d34523e8764bd66bf6a7bf97d06add420020 Mon Sep 17 00:00:00 2001 From: NoCrypt <57245077+NoCrypt@users.noreply.github.com> Date: Fri, 11 Nov 2022 08:54:51 +0700 Subject: [PATCH 03/72] Forcing HTTPS instead of HTTP for ngrok For security reason. --- modules/ngrok.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ngrok.py b/modules/ngrok.py index 5c5f349aa..25c53af85 100644 --- a/modules/ngrok.py +++ b/modules/ngrok.py @@ -8,7 +8,7 @@ def connect(token, port, region): auth_token=token, region=region ) try: - public_url = ngrok.connect(port, pyngrok_config=config).public_url + public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url except exception.PyngrokNgrokError: print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') From f4a488f585c09b420dc05199240e68f8fb74337f Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 7 Nov 2022 20:12:31 -0500 Subject: [PATCH 04/72] Set device for facelib/facexlib and gfpgan * FaceXLib/FaceLib doesn't pass the device argument to RetinaFace but instead chooses one itself and sets it to a global - in order to use a device other than its internally chosen default it is necessary to manually replace the default value * The GFPGAN constructor needs the device argument to work with MPS or a CUDA device ID that differs from the default --- modules/codeformer_model.py | 3 +++ modules/gfpgan_model.py | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index e6d9fa4f4..ab40d842c 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -36,6 +36,7 @@ def setup_model(dirname): from basicsr.utils.download_util import load_file_from_url from basicsr.utils import imwrite, img2tensor, tensor2img from facelib.utils.face_restoration_helper import FaceRestoreHelper + from facelib.detection.retinaface import retinaface from modules.shared import cmd_opts net_class = CodeFormer @@ -65,6 +66,8 @@ def setup_model(dirname): net.load_state_dict(checkpoint) net.eval() + if hasattr(retinaface, 'device'): + retinaface.device = devices.device_codeformer face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer) self.net = net diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index a9452dce5..1e2dbc32b 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -36,7 +36,9 @@ def gfpgann(): else: print("Unable to load gfpgan model!") return None - model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) + if hasattr(facexlib.detection.retinaface, 'device'): + facexlib.detection.retinaface.device = devices.device_gfpgan + model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) loaded_gfpgan_model = model return model From 007f4f7314eabd9cc3a2b0d11889de49ad3c682a Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sat, 12 Nov 2022 15:12:15 +0300 Subject: [PATCH 05/72] Tests cleaned up --- launch.py | 5 ++++- test/server_poll.py | 7 ++++--- test/test_files/empty.pt | Bin 0 -> 431 bytes test/txt2img_test.py | 4 +++- test/utils_test.py | 18 +++++++++--------- 5 files changed, 20 insertions(+), 14 deletions(-) create mode 100644 test/test_files/empty.pt diff --git a/launch.py b/launch.py index 8e65676d3..6822a01de 100644 --- a/launch.py +++ b/launch.py @@ -229,6 +229,9 @@ def prepare_enviroment(): def tests(argv): if "--api" not in argv: argv.append("--api") + if "--ckpt" not in argv: + argv.append("--ckpt") + argv.append("./test/test_files/empty.pt") print(f"Launching Web UI in another process for testing with arguments: {' '.join(argv[1:])}") @@ -236,7 +239,7 @@ def tests(argv): proc = subprocess.Popen([sys.executable, *argv], stdout=stdout, stderr=stderr) import test.server_poll - test.server_poll.run_tests() + test.server_poll.run_tests(proc) print(f"Stopping Web UI process with id {proc.pid}") proc.kill() diff --git a/test/server_poll.py b/test/server_poll.py index eeefb7ebc..8e63b4502 100644 --- a/test/server_poll.py +++ b/test/server_poll.py @@ -3,7 +3,7 @@ import requests import time -def run_tests(): +def run_tests(proc): timeout_threshold = 240 start_time = time.time() while time.time()-start_time < timeout_threshold: @@ -11,8 +11,9 @@ def run_tests(): requests.head("http://localhost:7860/") break except requests.exceptions.ConnectionError: - pass - if time.time()-start_time < timeout_threshold: + if proc.poll() is not None: + break + if proc.poll() is None: suite = unittest.TestLoader().discover('', pattern='*_test.py') result = unittest.TextTestRunner(verbosity=2).run(suite) else: diff --git a/test/test_files/empty.pt b/test/test_files/empty.pt new file mode 100644 index 0000000000000000000000000000000000000000..c6ac59eb01fcb778290a85f12bdb7867de3dfdd1 GIT binary patch literal 431 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfvL8TK`+3Yo#WF) zrlV{?Q$RQXr>Xo5ws2F+Qj3Z+^Yh%CEYS=_u>n8Fm@ES21PVa+A-Zm4llf6}h5>mn-B6zdc(bwTKo!X`>%x_T-2>#o=xV6UB`6Kl#|~op WGC~AERDd@tC?tV;m>59nA!-3{+(-BT literal 0 HcmV?d00001 diff --git a/test/txt2img_test.py b/test/txt2img_test.py index 1936e07e2..ce7520858 100644 --- a/test/txt2img_test.py +++ b/test/txt2img_test.py @@ -53,13 +53,15 @@ class TestTxt2ImgWorking(unittest.TestCase): self.simple_txt2img["restore_faces"] = True self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - def test_txt2img_with_tiling_faces_performed(self): + def test_txt2img_with_tiling_performed(self): self.simple_txt2img["tiling"] = True self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) def test_txt2img_with_vanilla_sampler_performed(self): self.simple_txt2img["sampler_index"] = "PLMS" self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + self.simple_txt2img["sampler_index"] = "DDIM" + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) def test_txt2img_multiple_batches_performed(self): self.simple_txt2img["n_iter"] = 2 diff --git a/test/utils_test.py b/test/utils_test.py index 65d3d177d..be9e6bf8c 100644 --- a/test/utils_test.py +++ b/test/utils_test.py @@ -18,19 +18,19 @@ class UtilsTests(unittest.TestCase): def test_options_get(self): self.assertEqual(requests.get(self.url_options).status_code, 200) - def test_options_write(self): - response = requests.get(self.url_options) - self.assertEqual(response.status_code, 200) + # def test_options_write(self): + # response = requests.get(self.url_options) + # self.assertEqual(response.status_code, 200) - pre_value = response.json()["send_seed"] + # pre_value = response.json()["send_seed"] - self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200) + # self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200) - response = requests.get(self.url_options) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json()["send_seed"], not pre_value) + # response = requests.get(self.url_options) + # self.assertEqual(response.status_code, 200) + # self.assertEqual(response.json()["send_seed"], not pre_value) - requests.post(self.url_options, json={"send_seed": pre_value}) + # requests.post(self.url_options, json={"send_seed": pre_value}) def test_cmd_flags(self): self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200) From d20dbe47e06de7f6c0e65242a04c9bb1410ef7cb Mon Sep 17 00:00:00 2001 From: Xu Cuijie <975114697@qq.com> Date: Sun, 13 Nov 2022 10:31:03 +0800 Subject: [PATCH 06/72] fix the model name error of Real-ESRGAN in the opts default value --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 6936cbe06..c46c29f77 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -299,7 +299,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), - "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), + "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), From 9a1aff645a4bea745145c57c96950fbd3fcca27c Mon Sep 17 00:00:00 2001 From: parasi Date: Sun, 13 Nov 2022 13:44:27 -0600 Subject: [PATCH 07/72] resolve [name] after resolving [filewords] in training --- modules/textual_inversion/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index eb75c3769..06f271f92 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -97,13 +97,13 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text): text = random.choice(self.lines) - text = text.replace("[name]", self.placeholder_token) tags = filename_text.split(',') if shared.opts.tag_drop_out != 0: tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] if shared.opts.shuffle_tags: random.shuffle(tags) text = text.replace("[filewords]", ','.join(tags)) + text = text.replace("[name]", self.placeholder_token) return text def __len__(self): From 93d6c0209ae55632b72751cf82740e32a0cd81bc Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 14 Nov 2022 13:39:22 +0300 Subject: [PATCH 08/72] Tests separated for github-actions CI --- .github/workflows/run_tests.yaml | 26 ++++++++++++ launch.py | 36 ++++++++++------ test/advanced_features/__init__.py | 0 test/{ => advanced_features}/extras_test.py | 4 +- test/advanced_features/txt2img_test.py | 47 +++++++++++++++++++++ test/basic_features/__init__.py | 0 test/{ => basic_features}/img2img_test.py | 4 -- test/{ => basic_features}/txt2img_test.py | 4 -- test/{ => basic_features}/utils_test.py | 6 ++- test/server_poll.py | 6 ++- 10 files changed, 108 insertions(+), 25 deletions(-) create mode 100644 .github/workflows/run_tests.yaml create mode 100644 test/advanced_features/__init__.py rename test/{ => advanced_features}/extras_test.py (90%) create mode 100644 test/advanced_features/txt2img_test.py create mode 100644 test/basic_features/__init__.py rename test/{ => basic_features}/img2img_test.py (96%) rename test/{ => basic_features}/txt2img_test.py (97%) rename test/{ => basic_features}/utils_test.py (97%) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml new file mode 100644 index 000000000..a56a81103 --- /dev/null +++ b/.github/workflows/run_tests.yaml @@ -0,0 +1,26 @@ +name: Run tests on CPU with empty model + +on: + - push + - pull_request + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: 3.10.6 + - uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Run tests + run: | + export COMMANDLINE_ARGS="--tests basic_features --no-half --disable-opt-split-attention --use-cpu all" + python launch.py diff --git a/launch.py b/launch.py index 6822a01de..d0f502c2b 100644 --- a/launch.py +++ b/launch.py @@ -17,6 +17,19 @@ def extract_arg(args, name): return [x for x in args if x != name], name in args +def extract_opt(args, name): + opt = None + is_present = False + if name in args: + is_present = True + idx = args.index(name) + del args[idx] + if idx < len(args) and args[idx][0] != "-": + opt = args[idx] + del args[idx] + return args, is_present, opt + + def run(command, desc=None, errdesc=None, custom_env=None): if desc is not None: print(desc) @@ -151,12 +164,11 @@ def prepare_enviroment(): blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") sys.argv += shlex.split(commandline_args) - test_argv = [x for x in sys.argv if x != '--tests'] sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') sys.argv, update_check = extract_arg(sys.argv, '--update-check') - sys.argv, run_tests = extract_arg(sys.argv, '--tests') + sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests') xformers = '--xformers' in sys.argv deepdanbooru = '--deepdanbooru' in sys.argv ngrok = '--ngrok' in sys.argv @@ -222,24 +234,24 @@ def prepare_enviroment(): exit(0) if run_tests: - tests(test_argv) + tests(test_dir) exit(0) -def tests(argv): - if "--api" not in argv: - argv.append("--api") - if "--ckpt" not in argv: - argv.append("--ckpt") - argv.append("./test/test_files/empty.pt") +def tests(test_dir): + if "--api" not in sys.argv: + sys.argv.append("--api") + if "--ckpt" not in sys.argv: + sys.argv.append("--ckpt") + sys.argv.append("./test/test_files/empty.pt") - print(f"Launching Web UI in another process for testing with arguments: {' '.join(argv[1:])}") + print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}") with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr: - proc = subprocess.Popen([sys.executable, *argv], stdout=stdout, stderr=stderr) + proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr) import test.server_poll - test.server_poll.run_tests(proc) + test.server_poll.run_tests(proc, test_dir) print(f"Stopping Web UI process with id {proc.pid}") proc.kill() diff --git a/test/advanced_features/__init__.py b/test/advanced_features/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/extras_test.py b/test/advanced_features/extras_test.py similarity index 90% rename from test/extras_test.py rename to test/advanced_features/extras_test.py index 9b8ce0f03..8763f8ed1 100644 --- a/test/extras_test.py +++ b/test/advanced_features/extras_test.py @@ -11,8 +11,8 @@ class TestExtrasWorking(unittest.TestCase): "codeformer_visibility": 0, "codeformer_weight": 0, "upscaling_resize": 2, - "upscaling_resize_w": 512, - "upscaling_resize_h": 512, + "upscaling_resize_w": 128, + "upscaling_resize_h": 128, "upscaling_crop": True, "upscaler_1": "None", "upscaler_2": "None", diff --git a/test/advanced_features/txt2img_test.py b/test/advanced_features/txt2img_test.py new file mode 100644 index 000000000..36ed7b9a9 --- /dev/null +++ b/test/advanced_features/txt2img_test.py @@ -0,0 +1,47 @@ +import unittest +import requests + + +class TestTxt2ImgWorking(unittest.TestCase): + def setUp(self): + self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img" + self.simple_txt2img = { + "enable_hr": False, + "denoising_strength": 0, + "firstphase_width": 0, + "firstphase_height": 0, + "prompt": "example prompt", + "styles": [], + "seed": -1, + "subseed": -1, + "subseed_strength": 0, + "seed_resize_from_h": -1, + "seed_resize_from_w": -1, + "batch_size": 1, + "n_iter": 1, + "steps": 3, + "cfg_scale": 7, + "width": 64, + "height": 64, + "restore_faces": False, + "tiling": False, + "negative_prompt": "", + "eta": 0, + "s_churn": 0, + "s_tmax": 0, + "s_tmin": 0, + "s_noise": 1, + "sampler_index": "Euler a" + } + + def test_txt2img_with_restore_faces_performed(self): + self.simple_txt2img["restore_faces"] = True + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + + +class TestTxt2ImgCorrectness(unittest.TestCase): + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/test/basic_features/__init__.py b/test/basic_features/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/img2img_test.py b/test/basic_features/img2img_test.py similarity index 96% rename from test/img2img_test.py rename to test/basic_features/img2img_test.py index 012a95809..0a9c1e8ad 100644 --- a/test/img2img_test.py +++ b/test/basic_features/img2img_test.py @@ -51,9 +51,5 @@ class TestImg2ImgWorking(unittest.TestCase): self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) -class TestImg2ImgCorrectness(unittest.TestCase): - pass - - if __name__ == "__main__": unittest.main() diff --git a/test/txt2img_test.py b/test/basic_features/txt2img_test.py similarity index 97% rename from test/txt2img_test.py rename to test/basic_features/txt2img_test.py index ce7520858..fe4af9991 100644 --- a/test/txt2img_test.py +++ b/test/basic_features/txt2img_test.py @@ -68,9 +68,5 @@ class TestTxt2ImgWorking(unittest.TestCase): self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) -class TestTxt2ImgCorrectness(unittest.TestCase): - pass - - if __name__ == "__main__": unittest.main() diff --git a/test/utils_test.py b/test/basic_features/utils_test.py similarity index 97% rename from test/utils_test.py rename to test/basic_features/utils_test.py index be9e6bf8c..9706db8b4 100644 --- a/test/utils_test.py +++ b/test/basic_features/utils_test.py @@ -60,4 +60,8 @@ class UtilsTests(unittest.TestCase): self.assertEqual(requests.get(self.url_artist_categories).status_code, 200) def test_artists(self): - self.assertEqual(requests.get(self.url_artists).status_code, 200) \ No newline at end of file + self.assertEqual(requests.get(self.url_artists).status_code, 200) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/server_poll.py b/test/server_poll.py index 8e63b4502..c71e906a6 100644 --- a/test/server_poll.py +++ b/test/server_poll.py @@ -3,7 +3,7 @@ import requests import time -def run_tests(proc): +def run_tests(proc, test_dir): timeout_threshold = 240 start_time = time.time() while time.time()-start_time < timeout_threshold: @@ -14,7 +14,9 @@ def run_tests(proc): if proc.poll() is not None: break if proc.poll() is None: - suite = unittest.TestLoader().discover('', pattern='*_test.py') + if test_dir is None: + test_dir = "" + suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test") result = unittest.TextTestRunner(verbosity=2).run(suite) else: print("Launch unsuccessful") From 3ffc1c6ceee169fac767a956fd0d4f153b005dbf Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 14 Nov 2022 13:45:21 +0300 Subject: [PATCH 09/72] skip cuda test --- .github/workflows/run_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index a56a81103..f30486817 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -22,5 +22,5 @@ jobs: ${{ runner.os }}-pip- - name: Run tests run: | - export COMMANDLINE_ARGS="--tests basic_features --no-half --disable-opt-split-attention --use-cpu all" + export COMMANDLINE_ARGS="--tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test" python launch.py From 0646040667b59526ac8346d53efd14dc0e75b01e Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 14 Nov 2022 14:36:07 +0300 Subject: [PATCH 10/72] Propagate test error and try it without localhost --- launch.py | 7 ++++--- test/advanced_features/extras_test.py | 2 +- test/advanced_features/txt2img_test.py | 2 +- test/basic_features/img2img_test.py | 2 +- test/basic_features/txt2img_test.py | 2 +- test/basic_features/utils_test.py | 22 +++++++++++----------- test/server_poll.py | 4 +++- 7 files changed, 22 insertions(+), 19 deletions(-) diff --git a/launch.py b/launch.py index d0f502c2b..8ed1dffca 100644 --- a/launch.py +++ b/launch.py @@ -234,8 +234,8 @@ def prepare_enviroment(): exit(0) if run_tests: - tests(test_dir) - exit(0) + exitcode = tests(test_dir) + exit(exitcode) def tests(test_dir): @@ -251,10 +251,11 @@ def tests(test_dir): proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr) import test.server_poll - test.server_poll.run_tests(proc, test_dir) + exitcode = test.server_poll.run_tests(proc, test_dir) print(f"Stopping Web UI process with id {proc.pid}") proc.kill() + return exitcode def start(): diff --git a/test/advanced_features/extras_test.py b/test/advanced_features/extras_test.py index 8763f8ed1..abdd5aa21 100644 --- a/test/advanced_features/extras_test.py +++ b/test/advanced_features/extras_test.py @@ -3,7 +3,7 @@ import unittest class TestExtrasWorking(unittest.TestCase): def setUp(self): - self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image" + self.url_img2img = "http://127.0.0.1:7860/sdapi/v1/extra-single-image" self.simple_extras = { "resize_mode": 0, "show_extras_results": True, diff --git a/test/advanced_features/txt2img_test.py b/test/advanced_features/txt2img_test.py index 36ed7b9a9..6ab5a2422 100644 --- a/test/advanced_features/txt2img_test.py +++ b/test/advanced_features/txt2img_test.py @@ -4,7 +4,7 @@ import requests class TestTxt2ImgWorking(unittest.TestCase): def setUp(self): - self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img" + self.url_txt2img = "http://127.0.0.1:7860/sdapi/v1/txt2img" self.simple_txt2img = { "enable_hr": False, "denoising_strength": 0, diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py index 0a9c1e8ad..b9a341135 100644 --- a/test/basic_features/img2img_test.py +++ b/test/basic_features/img2img_test.py @@ -6,7 +6,7 @@ from PIL import Image class TestImg2ImgWorking(unittest.TestCase): def setUp(self): - self.url_img2img = "http://localhost:7860/sdapi/v1/img2img" + self.url_img2img = "http://127.0.0.1:7860/sdapi/v1/img2img" self.simple_img2img = { "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))], "resize_mode": 0, diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py index fe4af9991..0be675d92 100644 --- a/test/basic_features/txt2img_test.py +++ b/test/basic_features/txt2img_test.py @@ -4,7 +4,7 @@ import requests class TestTxt2ImgWorking(unittest.TestCase): def setUp(self): - self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img" + self.url_txt2img = "http://127.0.0.1:7860/sdapi/v1/txt2img" self.simple_txt2img = { "enable_hr": False, "denoising_strength": 0, diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py index 9706db8b4..592f76196 100644 --- a/test/basic_features/utils_test.py +++ b/test/basic_features/utils_test.py @@ -3,17 +3,17 @@ import requests class UtilsTests(unittest.TestCase): def setUp(self): - self.url_options = "http://localhost:7860/sdapi/v1/options" - self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags" - self.url_samplers = "http://localhost:7860/sdapi/v1/samplers" - self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers" - self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models" - self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks" - self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers" - self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models" - self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles" - self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories" - self.url_artists = "http://localhost:7860/sdapi/v1/artists" + self.url_options = "http://127.0.0.1:7860/sdapi/v1/options" + self.url_cmd_flags = "http://127.0.0.1:7860/sdapi/v1/cmd-flags" + self.url_samplers = "http://127.0.0.1:7860/sdapi/v1/samplers" + self.url_upscalers = "http://127.0.0.1:7860/sdapi/v1/upscalers" + self.url_sd_models = "http://127.0.0.1:7860/sdapi/v1/sd-models" + self.url_hypernetworks = "http://127.0.0.1:7860/sdapi/v1/hypernetworks" + self.url_face_restorers = "http://127.0.0.1:7860/sdapi/v1/face-restorers" + self.url_realesrgan_models = "http://127.0.0.1:7860/sdapi/v1/realesrgan-models" + self.url_prompt_styles = "http://127.0.0.1:7860/sdapi/v1/prompt-styles" + self.url_artist_categories = "http://127.0.0.1:7860/sdapi/v1/artist-categories" + self.url_artists = "http://127.0.0.1:7860/sdapi/v1/artists" def test_options_get(self): self.assertEqual(requests.get(self.url_options).status_code, 200) diff --git a/test/server_poll.py b/test/server_poll.py index c71e906a6..028fd476b 100644 --- a/test/server_poll.py +++ b/test/server_poll.py @@ -8,7 +8,7 @@ def run_tests(proc, test_dir): start_time = time.time() while time.time()-start_time < timeout_threshold: try: - requests.head("http://localhost:7860/") + requests.head("http://127.0.0.1:7860/") break except requests.exceptions.ConnectionError: if proc.poll() is not None: @@ -18,5 +18,7 @@ def run_tests(proc, test_dir): test_dir = "" suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test") result = unittest.TextTestRunner(verbosity=2).run(suite) + return len(result.failures) else: print("Launch unsuccessful") + return 1 From 7416ac8d3cadcc6a53bbcc41e1cd184fa1587afd Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 14 Nov 2022 14:55:39 +0300 Subject: [PATCH 11/72] Use localhost with 80 port, count errors as well --- test/advanced_features/extras_test.py | 2 +- test/advanced_features/txt2img_test.py | 2 +- test/basic_features/img2img_test.py | 2 +- test/basic_features/txt2img_test.py | 2 +- test/basic_features/utils_test.py | 22 +++++++++++----------- test/server_poll.py | 4 ++-- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/test/advanced_features/extras_test.py b/test/advanced_features/extras_test.py index abdd5aa21..4b8ae25ae 100644 --- a/test/advanced_features/extras_test.py +++ b/test/advanced_features/extras_test.py @@ -3,7 +3,7 @@ import unittest class TestExtrasWorking(unittest.TestCase): def setUp(self): - self.url_img2img = "http://127.0.0.1:7860/sdapi/v1/extra-single-image" + self.url_img2img = "http://localhost:80/sdapi/v1/extra-single-image" self.simple_extras = { "resize_mode": 0, "show_extras_results": True, diff --git a/test/advanced_features/txt2img_test.py b/test/advanced_features/txt2img_test.py index 6ab5a2422..e6c3531a7 100644 --- a/test/advanced_features/txt2img_test.py +++ b/test/advanced_features/txt2img_test.py @@ -4,7 +4,7 @@ import requests class TestTxt2ImgWorking(unittest.TestCase): def setUp(self): - self.url_txt2img = "http://127.0.0.1:7860/sdapi/v1/txt2img" + self.url_txt2img = "http://localhost:80/sdapi/v1/txt2img" self.simple_txt2img = { "enable_hr": False, "denoising_strength": 0, diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py index b9a341135..c4c9a90fe 100644 --- a/test/basic_features/img2img_test.py +++ b/test/basic_features/img2img_test.py @@ -6,7 +6,7 @@ from PIL import Image class TestImg2ImgWorking(unittest.TestCase): def setUp(self): - self.url_img2img = "http://127.0.0.1:7860/sdapi/v1/img2img" + self.url_img2img = "http://localhost:80/sdapi/v1/img2img" self.simple_img2img = { "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))], "resize_mode": 0, diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py index 0be675d92..4f1fc77d3 100644 --- a/test/basic_features/txt2img_test.py +++ b/test/basic_features/txt2img_test.py @@ -4,7 +4,7 @@ import requests class TestTxt2ImgWorking(unittest.TestCase): def setUp(self): - self.url_txt2img = "http://127.0.0.1:7860/sdapi/v1/txt2img" + self.url_txt2img = "http://localhost:80/sdapi/v1/txt2img" self.simple_txt2img = { "enable_hr": False, "denoising_strength": 0, diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py index 592f76196..fdb72b9aa 100644 --- a/test/basic_features/utils_test.py +++ b/test/basic_features/utils_test.py @@ -3,17 +3,17 @@ import requests class UtilsTests(unittest.TestCase): def setUp(self): - self.url_options = "http://127.0.0.1:7860/sdapi/v1/options" - self.url_cmd_flags = "http://127.0.0.1:7860/sdapi/v1/cmd-flags" - self.url_samplers = "http://127.0.0.1:7860/sdapi/v1/samplers" - self.url_upscalers = "http://127.0.0.1:7860/sdapi/v1/upscalers" - self.url_sd_models = "http://127.0.0.1:7860/sdapi/v1/sd-models" - self.url_hypernetworks = "http://127.0.0.1:7860/sdapi/v1/hypernetworks" - self.url_face_restorers = "http://127.0.0.1:7860/sdapi/v1/face-restorers" - self.url_realesrgan_models = "http://127.0.0.1:7860/sdapi/v1/realesrgan-models" - self.url_prompt_styles = "http://127.0.0.1:7860/sdapi/v1/prompt-styles" - self.url_artist_categories = "http://127.0.0.1:7860/sdapi/v1/artist-categories" - self.url_artists = "http://127.0.0.1:7860/sdapi/v1/artists" + self.url_options = "http://localhost:80/sdapi/v1/options" + self.url_cmd_flags = "http://localhost:80/sdapi/v1/cmd-flags" + self.url_samplers = "http://localhost:80/sdapi/v1/samplers" + self.url_upscalers = "http://localhost:80/sdapi/v1/upscalers" + self.url_sd_models = "http://localhost:80/sdapi/v1/sd-models" + self.url_hypernetworks = "http://localhost:80/sdapi/v1/hypernetworks" + self.url_face_restorers = "http://localhost:80/sdapi/v1/face-restorers" + self.url_realesrgan_models = "http://localhost:80/sdapi/v1/realesrgan-models" + self.url_prompt_styles = "http://localhost:80/sdapi/v1/prompt-styles" + self.url_artist_categories = "http://localhost:80/sdapi/v1/artist-categories" + self.url_artists = "http://localhost:80/sdapi/v1/artists" def test_options_get(self): self.assertEqual(requests.get(self.url_options).status_code, 200) diff --git a/test/server_poll.py b/test/server_poll.py index 028fd476b..e4462b6c9 100644 --- a/test/server_poll.py +++ b/test/server_poll.py @@ -8,7 +8,7 @@ def run_tests(proc, test_dir): start_time = time.time() while time.time()-start_time < timeout_threshold: try: - requests.head("http://127.0.0.1:7860/") + requests.head("http://localhost:80/") break except requests.exceptions.ConnectionError: if proc.poll() is not None: @@ -18,7 +18,7 @@ def run_tests(proc, test_dir): test_dir = "" suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test") result = unittest.TextTestRunner(verbosity=2).run(suite) - return len(result.failures) + return len(result.failures) + len(result.errors) else: print("Launch unsuccessful") return 1 From 5808241dd76983212ab8e27b07f72866671f0b2d Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 14 Nov 2022 15:14:52 +0300 Subject: [PATCH 12/72] Use 80 port on launch --- .github/workflows/run_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index f30486817..223a31b93 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -22,5 +22,5 @@ jobs: ${{ runner.os }}-pip- - name: Run tests run: | - export COMMANDLINE_ARGS="--tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test" + export COMMANDLINE_ARGS="--tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test --port 80" python launch.py From 9e4f68acad4697fdf10004eb85d3f6f769c2c70b Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 14 Nov 2022 18:40:15 +0300 Subject: [PATCH 13/72] Stop exporting cl args and upload stdout and stderr as artifacts --- .github/workflows/run_tests.yaml | 11 +++++++++-- launch.py | 2 ++ test/advanced_features/extras_test.py | 2 +- test/advanced_features/txt2img_test.py | 2 +- test/basic_features/img2img_test.py | 2 +- test/basic_features/txt2img_test.py | 2 +- test/basic_features/utils_test.py | 22 +++++++++++----------- test/server_poll.py | 2 +- 8 files changed, 27 insertions(+), 18 deletions(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 223a31b93..558b0c615 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -22,5 +22,12 @@ jobs: ${{ runner.os }}-pip- - name: Run tests run: | - export COMMANDLINE_ARGS="--tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test --port 80" - python launch.py + python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test + - name: Upload main app stdout-stderr + uses: actions/upload-artifact@v3 + if: always() + with: + name: stdout-stderr + path: | + ./test/stdout.txt + ./test/stderr.txt diff --git a/launch.py b/launch.py index 8ed1dffca..ed21d0e68 100644 --- a/launch.py +++ b/launch.py @@ -244,6 +244,8 @@ def tests(test_dir): if "--ckpt" not in sys.argv: sys.argv.append("--ckpt") sys.argv.append("./test/test_files/empty.pt") + if "--skip-torch-cuda-test" not in sys.argv: + sys.argv.append("--skip-torch-cuda-test") print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}") diff --git a/test/advanced_features/extras_test.py b/test/advanced_features/extras_test.py index 4b8ae25ae..8763f8ed1 100644 --- a/test/advanced_features/extras_test.py +++ b/test/advanced_features/extras_test.py @@ -3,7 +3,7 @@ import unittest class TestExtrasWorking(unittest.TestCase): def setUp(self): - self.url_img2img = "http://localhost:80/sdapi/v1/extra-single-image" + self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image" self.simple_extras = { "resize_mode": 0, "show_extras_results": True, diff --git a/test/advanced_features/txt2img_test.py b/test/advanced_features/txt2img_test.py index e6c3531a7..36ed7b9a9 100644 --- a/test/advanced_features/txt2img_test.py +++ b/test/advanced_features/txt2img_test.py @@ -4,7 +4,7 @@ import requests class TestTxt2ImgWorking(unittest.TestCase): def setUp(self): - self.url_txt2img = "http://localhost:80/sdapi/v1/txt2img" + self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img" self.simple_txt2img = { "enable_hr": False, "denoising_strength": 0, diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py index c4c9a90fe..0a9c1e8ad 100644 --- a/test/basic_features/img2img_test.py +++ b/test/basic_features/img2img_test.py @@ -6,7 +6,7 @@ from PIL import Image class TestImg2ImgWorking(unittest.TestCase): def setUp(self): - self.url_img2img = "http://localhost:80/sdapi/v1/img2img" + self.url_img2img = "http://localhost:7860/sdapi/v1/img2img" self.simple_img2img = { "init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))], "resize_mode": 0, diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py index 4f1fc77d3..fe4af9991 100644 --- a/test/basic_features/txt2img_test.py +++ b/test/basic_features/txt2img_test.py @@ -4,7 +4,7 @@ import requests class TestTxt2ImgWorking(unittest.TestCase): def setUp(self): - self.url_txt2img = "http://localhost:80/sdapi/v1/txt2img" + self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img" self.simple_txt2img = { "enable_hr": False, "denoising_strength": 0, diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py index fdb72b9aa..9706db8b4 100644 --- a/test/basic_features/utils_test.py +++ b/test/basic_features/utils_test.py @@ -3,17 +3,17 @@ import requests class UtilsTests(unittest.TestCase): def setUp(self): - self.url_options = "http://localhost:80/sdapi/v1/options" - self.url_cmd_flags = "http://localhost:80/sdapi/v1/cmd-flags" - self.url_samplers = "http://localhost:80/sdapi/v1/samplers" - self.url_upscalers = "http://localhost:80/sdapi/v1/upscalers" - self.url_sd_models = "http://localhost:80/sdapi/v1/sd-models" - self.url_hypernetworks = "http://localhost:80/sdapi/v1/hypernetworks" - self.url_face_restorers = "http://localhost:80/sdapi/v1/face-restorers" - self.url_realesrgan_models = "http://localhost:80/sdapi/v1/realesrgan-models" - self.url_prompt_styles = "http://localhost:80/sdapi/v1/prompt-styles" - self.url_artist_categories = "http://localhost:80/sdapi/v1/artist-categories" - self.url_artists = "http://localhost:80/sdapi/v1/artists" + self.url_options = "http://localhost:7860/sdapi/v1/options" + self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags" + self.url_samplers = "http://localhost:7860/sdapi/v1/samplers" + self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers" + self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models" + self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks" + self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers" + self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models" + self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles" + self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories" + self.url_artists = "http://localhost:7860/sdapi/v1/artists" def test_options_get(self): self.assertEqual(requests.get(self.url_options).status_code, 200) diff --git a/test/server_poll.py b/test/server_poll.py index e4462b6c9..d4df697b2 100644 --- a/test/server_poll.py +++ b/test/server_poll.py @@ -8,7 +8,7 @@ def run_tests(proc, test_dir): start_time = time.time() while time.time()-start_time < timeout_threshold: try: - requests.head("http://localhost:80/") + requests.head("http://localhost:7860/") break except requests.exceptions.ConnectionError: if proc.poll() is not None: From 4a35c3744c68be470e6caf72689322ed58f90aac Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 14 Nov 2022 18:57:14 +0300 Subject: [PATCH 14/72] remove test requiring codeformers --- test/basic_features/txt2img_test.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py index fe4af9991..1c2674b2a 100644 --- a/test/basic_features/txt2img_test.py +++ b/test/basic_features/txt2img_test.py @@ -49,10 +49,6 @@ class TestTxt2ImgWorking(unittest.TestCase): self.simple_txt2img["enable_hr"] = True self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - def test_txt2img_with_restore_faces_performed(self): - self.simple_txt2img["restore_faces"] = True - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - def test_txt2img_with_tiling_performed(self): self.simple_txt2img["tiling"] = True self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) From a07107900019c8f18ed0f8d07e39435ee3189e79 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 14 Nov 2022 19:22:06 +0300 Subject: [PATCH 15/72] Use empty model as CLIP weights --- .github/workflows/run_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 558b0c615..4aeeab9c7 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -22,7 +22,7 @@ jobs: ${{ runner.os }}-pip- - name: Run tests run: | - python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test + python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test --clip-models-path ./test/test_files/empty.pt - name: Upload main app stdout-stderr uses: actions/upload-artifact@v3 if: always() From abfa22c16fb3d9b1ed8d049c7b68e94d1cca5b82 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 7 Nov 2022 19:25:43 -0500 Subject: [PATCH 16/72] Revert "MPS Upscalers Fix" This reverts commit 768b95394a8500da639b947508f78296524f1836. --- modules/devices.py | 9 --------- modules/esrgan_model.py | 2 +- modules/scunet_model.py | 3 ++- modules/swinir_model.py | 2 +- 4 files changed, 4 insertions(+), 12 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 67165bf66..a87d0d4c9 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -94,12 +94,3 @@ def autocast(disable=False): return contextlib.nullcontext() return torch.autocast("cuda") - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 -def mps_contiguous(input_tensor, device): - return input_tensor.contiguous() if device.type == 'mps' else input_tensor - - -def mps_contiguous_to(input_tensor, device): - return mps_contiguous(input_tensor, device).to(device) diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index c61669b46..9a9c38f1f 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -199,7 +199,7 @@ def upscale_without_tiling(model, img): img = img[:, :, ::-1] img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan) + img = img.unsqueeze(0).to(devices.device_esrgan) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/scunet_model.py b/modules/scunet_model.py index 59532274f..36a996bf0 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -54,8 +54,9 @@ class UpscalerScuNET(modules.upscaler.Upscaler): img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), device) + img = img.unsqueeze(0).to(device) + img = img.to(device) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/swinir_model.py b/modules/swinir_model.py index 4253b66d5..facd262db 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -111,7 +111,7 @@ def upscale( img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) + img = img.unsqueeze(0).to(devices.device_swinir) with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old From a5106a7cdc24153332e4eb1d28e66ea1d7f1ef79 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 7 Nov 2022 19:44:27 -0500 Subject: [PATCH 17/72] Remove extra .to(device) --- modules/scunet_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/scunet_model.py b/modules/scunet_model.py index 36a996bf0..523602413 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -56,7 +56,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): img = torch.from_numpy(img).float() img = img.unsqueeze(0).to(device) - img = img.to(device) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() From 694611cbd8151533cf18e01a8ead32771ea12ff3 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sat, 19 Nov 2022 00:56:08 +0300 Subject: [PATCH 18/72] Apply suggestions from code review Use last version of setup-python action Remove unnecesarry multicommand from run Remove current directory from artifact paths Co-authored-by: Margen67 --- .github/workflows/run_tests.yaml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 4aeeab9c7..6c6f15a89 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -11,7 +11,7 @@ jobs: - name: Checkout Code uses: actions/checkout@v3 - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: 3.10.6 - uses: actions/cache@v2 @@ -21,13 +21,12 @@ jobs: restore-keys: | ${{ runner.os }}-pip- - name: Run tests - run: | - python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test --clip-models-path ./test/test_files/empty.pt + run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test --clip-models-path ./test/test_files/empty.pt - name: Upload main app stdout-stderr uses: actions/upload-artifact@v3 if: always() with: name: stdout-stderr path: | - ./test/stdout.txt - ./test/stderr.txt + test/stdout.txt + test/stderr.txt From 14dfede8ddbf6b82bb290d5c4292d52b6b68e3ca Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sat, 19 Nov 2022 14:15:10 +0300 Subject: [PATCH 19/72] Minor fixes Remove unused test completely Change job name Don't use empty.pt as CLIP weights - it wont work. Use latest version of actions/cache --- .github/workflows/run_tests.yaml | 9 ++++----- test/basic_features/utils_test.py | 14 -------------- 2 files changed, 4 insertions(+), 19 deletions(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 6c6f15a89..49dc92bd9 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -1,4 +1,4 @@ -name: Run tests on CPU with empty model +name: Run basic features tests on CPU with empty SD model on: - push @@ -14,14 +14,13 @@ jobs: uses: actions/setup-python@v4 with: python-version: 3.10.6 - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} - restore-keys: | - ${{ runner.os }}-pip- + restore-keys: ${{ runner.os }}-pip- - name: Run tests - run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test --clip-models-path ./test/test_files/empty.pt + run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test - name: Upload main app stdout-stderr uses: actions/upload-artifact@v3 if: always() diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py index 9706db8b4..765470c90 100644 --- a/test/basic_features/utils_test.py +++ b/test/basic_features/utils_test.py @@ -18,20 +18,6 @@ class UtilsTests(unittest.TestCase): def test_options_get(self): self.assertEqual(requests.get(self.url_options).status_code, 200) - # def test_options_write(self): - # response = requests.get(self.url_options) - # self.assertEqual(response.status_code, 200) - - # pre_value = response.json()["send_seed"] - - # self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200) - - # response = requests.get(self.url_options) - # self.assertEqual(response.status_code, 200) - # self.assertEqual(response.json()["send_seed"], not pre_value) - - # requests.post(self.url_options, json={"send_seed": pre_value}) - def test_cmd_flags(self): self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200) From ac7ecd2d847bf4e3a9503db0f2a291e32b82302c Mon Sep 17 00:00:00 2001 From: Tim Patton <38817597+pattontim@users.noreply.github.com> Date: Sat, 19 Nov 2022 14:49:22 -0500 Subject: [PATCH 20/72] Label and load SD .safetensors model files --- README.md | 1 + modules/modelloader.py | 1 + modules/sd_models.py | 24 ++++++++++++++++-------- requirements.txt | 1 + 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 33508f311..ba9f3952f 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - API - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) +- Can use safetensors to safely load model files without python pickle ## Where are Aesthetic Gradients?!?! Aesthetic Gradients are now an extension. You can install it using git: diff --git a/modules/modelloader.py b/modules/modelloader.py index e4a6f8acb..7d2f0ade9 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -82,6 +82,7 @@ def cleanup_models(): src_path = models_path dest_path = os.path.join(models_path, "Stable-diffusion") move_files(src_path, dest_path, ".ckpt") + move_files(src_path, dest_path, ".safetensors") src_path = os.path.join(root_path, "ESRGAN") dest_path = os.path.join(models_path, "ESRGAN") move_files(src_path, dest_path) diff --git a/modules/sd_models.py b/modules/sd_models.py index c59151e03..4ccdf30b4 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,6 +4,7 @@ import sys import gc from collections import namedtuple import torch +from safetensors.torch import load_file import re from omegaconf import OmegaConf @@ -16,9 +17,10 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config', 'exttype']) checkpoints_list = {} checkpoints_loaded = collections.OrderedDict() +checkpoint_types = {'.ckpt':'pickle','.safetensors':'safetensors'} try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -45,7 +47,7 @@ def checkpoint_tiles(): def list_models(): checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"]) + model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt",".safetensors"]) def modeltitle(path, shorthash): abspath = os.path.abspath(path) @@ -60,15 +62,15 @@ def list_models(): if name.startswith("\\") or name.startswith("/"): name = name[1:] - shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + shortname, ext = os.path.splitext(name.replace("/", "_").replace("\\", "_")) - return f'{name} [{shorthash}]', shortname + return f'{name} [{checkpoint_types[ext]}] [{shorthash}]', shortname cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config, '') shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) @@ -76,12 +78,12 @@ def list_models(): h = model_hash(filename) title, short_model_name = modeltitle(filename, h) - basename, _ = os.path.splitext(filename) + basename, ext = os.path.splitext(filename) config = basename + ".yaml" if not os.path.exists(config): config = shared.cmd_opts.config - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config, ext) def get_closet_checkpoint_match(searchString): @@ -173,7 +175,13 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if(checkpoint_types[checkpoint_info.exttype] == 'safetensors'): + # safely load weights + # TODO: safetensors supports zero copy fast load to gpu, see issue #684 + pl_sd = load_file(checkpoint_file, device=shared.weight_load_location) + else: + pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") diff --git a/requirements.txt b/requirements.txt index 762db4f34..f7de9f707 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ kornia lark inflection GitPython +safetensors From bd68e35de3b7cf7547ed97d8bdf60147402133cc Mon Sep 17 00:00:00 2001 From: flamelaw Date: Sun, 20 Nov 2022 12:35:26 +0900 Subject: [PATCH 21/72] Gradient accumulation, autocast fix, new latent sampling method, etc --- modules/hypernetworks/hypernetwork.py | 251 ++++++++-------- modules/sd_hijack.py | 9 +- modules/sd_hijack_checkpoint.py | 10 + modules/shared.py | 3 +- modules/textual_inversion/dataset.py | 120 +++++--- .../textual_inversion/textual_inversion.py | 272 ++++++++++-------- modules/ui.py | 16 +- 7 files changed, 408 insertions(+), 273 deletions(-) create mode 100644 modules/sd_hijack_checkpoint.py diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index fbb87dd14..3d3301b08 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -367,13 +367,13 @@ def report_statistics(loss_info:dict): -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 - textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() @@ -403,28 +403,24 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log hypernetwork = shared.loaded_hypernetwork checkpoint = sd_models.select_checkpoint() - ititial_step = hypernetwork.step or 0 - if ititial_step >= steps: + initial_step = hypernetwork.step or 0 + if initial_step >= steps: shared.state.textinfo = f"Model has already been trained beyond specified max steps" return hypernetwork, filename - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." - with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + + pin_memory = shared.opts.pin_memory + + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=pin_memory) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - - size = len(ds.indexes) - loss_dict = defaultdict(lambda : deque(maxlen = 1024)) - losses = torch.zeros((size,)) - previous_mean_losses = [0] - previous_mean_loss = 0 - print("Mean loss of {} elements".format(size)) weights = hypernetwork.weights() for weight in weights: @@ -436,8 +432,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log optimizer_name = hypernetwork.optimizer_name else: print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") - optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) - optimizer_name = 'AdamW' + optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) + optimizer_name = 'AdamW' if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. try: @@ -446,131 +442,155 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log print("Cannot resume from saved optimizer!") print(e) + scaler = torch.cuda.amp.GradScaler() + + batch_size = ds.batch_size + gradient_step = ds.gradient_step + # n steps = batch_size * gradient_step * n image processed + steps_per_epoch = len(ds) // batch_size // gradient_step + max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step + loss_step = 0 + _loss_step = 0 #internal + # size = len(ds.indexes) + # loss_dict = defaultdict(lambda : deque(maxlen = 1024)) + # losses = torch.zeros((size,)) + # previous_mean_losses = [0] + # previous_mean_loss = 0 + # print("Mean loss of {} elements".format(size)) + steps_without_grad = 0 last_saved_file = "" last_saved_image = "" forced_filename = "" - pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, entries in pbar: - hypernetwork.step = i + ititial_step - if len(loss_dict) > 0: - previous_mean_losses = [i[-1] for i in loss_dict.values()] - previous_mean_loss = mean(previous_mean_losses) - - scheduler.apply(optimizer, hypernetwork.step) - if scheduler.finished: - break + pbar = tqdm.tqdm(total=steps - initial_step) + try: + for i in range((steps-initial_step) * gradient_step): + if scheduler.finished: + break + if shared.state.interrupted: + break + for j, batch in enumerate(dl): + # works as a drop_last=True for gradient accumulation + if j == max_steps_per_epoch: + break + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break + if shared.state.interrupted: + break - if shared.state.interrupted: - break + with torch.autocast("cuda"): + x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) + if tag_drop_out != 0 or shuffle_tags: + shared.sd_model.cond_stage_model.to(devices.device) + c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory) + shared.sd_model.cond_stage_model.to(devices.cpu) + else: + c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory) + loss = shared.sd_model(x, c)[0] / gradient_step + del x + del c - with torch.autocast("cuda"): - c = stack_conds([entry.cond for entry in entries]).to(devices.device) - # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - del c + _loss_step += loss.item() + scaler.scale(loss).backward() + # go back until we reach gradient accumulation steps + if (j + 1) % gradient_step != 0: + continue + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}") + # scaler.unscale_(optimizer) + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") + # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0) + # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}") + scaler.step(optimizer) + scaler.update() + hypernetwork.step += 1 + pbar.update() + optimizer.zero_grad(set_to_none=True) + loss_step = _loss_step + _loss_step = 0 - losses[hypernetwork.step % losses.shape[0]] = loss.item() - for entry in entries: - loss_dict[entry.filename].append(loss.item()) + steps_done = hypernetwork.step + 1 - optimizer.zero_grad() - weights[0].grad = None - loss.backward() + epoch_num = hypernetwork.step // steps_per_epoch + epoch_step = hypernetwork.step % steps_per_epoch - if weights[0].grad is None: - steps_without_grad += 1 - else: - steps_without_grad = 0 - assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: + # Before saving, change name to match current checkpoint. + hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') + hypernetwork.optimizer_name = optimizer_name + if shared.opts.save_optimizer_state: + hypernetwork.optimizer_state_dict = optimizer.state_dict() + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) + hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. - optimizer.step() + textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, { + "loss": f"{loss_step:.7f}", + "learn_rate": scheduler.learn_rate + }) - steps_done = hypernetwork.step + 1 + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{hypernetwork_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) - if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): - raise RuntimeError("Loss diverged.") - - if len(previous_mean_losses) > 1: - std = stdev(previous_mean_losses) - else: - std = 0 - dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" - pbar.set_description(dataset_loss_info) + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) - if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: - # Before saving, change name to match current checkpoint. - hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') - hypernetwork.optimizer_name = optimizer_name - if shared.opts.save_optimizer_state: - hypernetwork.optimizer_state_dict = optimizer.state_dict() - save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) - hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) - textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { - "loss": f"{previous_mean_loss:.7f}", - "learn_rate": scheduler.learn_rate - }) + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = batch.cond_text[0] + p.steps = 20 + p.width = training_width + p.height = training_height - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{hypernetwork_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) + preview_text = p.prompt - optimizer.zero_grad() - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - ) + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_name = sd_samplers.samplers[preview_sampler_index].name - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - preview_text = p.prompt + shared.state.job_no = hypernetwork.step - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None - - if unload: - shared.sd_model.cond_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.cpu) - - if image is not None: - shared.state.current_image = image - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" - - shared.state.job_no = hypernetwork.step - - shared.state.textinfo = f""" + shared.state.textinfo = f"""

-Loss: {previous_mean_loss:.7f}
+Loss: {loss_step:.7f}
Step: {hypernetwork.step}
-Last prompt: {html.escape(entries[0].cond_text)}
+Last prompt: {html.escape(batch.cond_text[0])}
Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - - report_statistics(loss_dict) + except Exception: + print(traceback.format_exc(), file=sys.stderr) + finally: + pbar.leave = False + pbar.close() + #report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') hypernetwork.optimizer_name = optimizer_name @@ -579,6 +599,9 @@ Last saved image: {html.escape(last_saved_image)}
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) del optimizer hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + return hypernetwork, filename def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13e..29c8b5613 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -8,7 +8,7 @@ from torch import einsum from torch.nn.functional import silu import modules.textual_inversion.textual_inversion -from modules import prompt_parser, devices, sd_hijack_optimizations, shared +from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.shared import opts, device, cmd_opts from modules.sd_hijack_optimizations import invokeAI_mps_available @@ -59,6 +59,10 @@ def undo_optimizations(): def get_target_prompt_token_count(token_count): return math.ceil(max(token_count, 1) / 75) * 75 +def fix_checkpoint(): + ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward class StableDiffusionModelHijack: fixes = None @@ -78,6 +82,7 @@ class StableDiffusionModelHijack: self.clip = m.cond_stage_model apply_optimizations() + fix_checkpoint() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -303,7 +308,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - + self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py new file mode 100644 index 000000000..5712972f1 --- /dev/null +++ b/modules/sd_hijack_checkpoint.py @@ -0,0 +1,10 @@ +from torch.utils.checkpoint import checkpoint + +def BasicTransformerBlock_forward(self, x, context=None): + return checkpoint(self._forward, x, context) + +def AttentionBlock_forward(self, x): + return checkpoint(self._forward, x) + +def ResBlock_forward(self, x, emb): + return checkpoint(self._forward, x, emb) \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index a4457305b..3704ce239 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -322,8 +322,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), - "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."), - "tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}), + "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index eb75c3769..d594b49d2 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -3,7 +3,7 @@ import numpy as np import PIL import torch from PIL import Image -from torch.utils.data import Dataset +from torch.utils.data import Dataset, DataLoader from torchvision import transforms import random @@ -11,25 +11,28 @@ import tqdm from modules import devices, shared import re +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, latent=None, filename_text=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): self.filename = filename - self.latent = latent self.filename_text = filename_text - self.cond = None - self.cond_text = None + self.latent_dist = latent_dist + self.latent_sample = latent_sample + self.cond = cond + self.cond_text = cond_text + self.pixel_values = pixel_values class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None - + self.placeholder_token = placeholder_token - self.batch_size = batch_size self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) @@ -45,11 +48,16 @@ class PersonalizedBase(Dataset): assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - cond_model = shared.sd_model.cond_stage_model - self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + + + self.shuffle_tags = shuffle_tags + self.tag_drop_out = tag_drop_out + print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): + if shared.state.interrupted: + raise Exception("inturrupted") try: image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) except Exception: @@ -71,37 +79,58 @@ class PersonalizedBase(Dataset): npimage = np.array(image).astype(np.uint8) npimage = (npimage / 127.5 - 1.0).astype(np.float32) - torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) - torchdata = torch.moveaxis(torchdata, 2, 0) + torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) + latent_sample = None - init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() - init_latent = init_latent.to(devices.cpu) + with torch.autocast("cuda"): + latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) + if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + latent_sampling_method = "once" + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + elif latent_sampling_method == "deterministic": + # Works only for DiagonalGaussianDistribution + latent_dist.std = 0 + latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + elif latent_sampling_method == "random": + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) - if include_cond: + if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) - entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) + + if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): + with torch.autocast("cuda"): + entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) + # elif not include_cond: + # _, _, _, _, hijack_fixes, token_count = cond_model.process_text([entry.cond_text]) + # max_n = token_count // 75 + # index_list = [ [] for _ in range(max_n + 1) ] + # for n, (z, _) in hijack_fixes[0]: + # index_list[n].append(z) + # with torch.autocast("cuda"): + # entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) + # entry.emb_index = index_list self.dataset.append(entry) + del torchdata + del latent_dist + del latent_sample - assert len(self.dataset) > 0, "No images have been found in the dataset." - self.length = len(self.dataset) * repeats // batch_size - - self.dataset_length = len(self.dataset) - self.indexes = None - self.shuffle() - - def shuffle(self): - self.indexes = np.random.permutation(self.dataset_length) + self.length = len(self.dataset) + assert self.length > 0, "No images have been found in the dataset." + self.batch_size = min(batch_size, self.length) + self.gradient_step = min(gradient_step, self.length // self.batch_size) + self.latent_sampling_method = latent_sampling_method def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) tags = filename_text.split(',') - if shared.opts.tag_drop_out != 0: - tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] - if shared.opts.shuffle_tags: + if self.tag_drop_out != 0: + tags = [t for t in tags if random.random() > self.tag_drop_out] + if self.shuffle_tags: random.shuffle(tags) text = text.replace("[filewords]", ','.join(tags)) return text @@ -110,19 +139,28 @@ class PersonalizedBase(Dataset): return self.length def __getitem__(self, i): - res = [] + entry = self.dataset[i] + if self.tag_drop_out != 0 or self.shuffle_tags: + entry.cond_text = self.create_text(entry.filename_text) + if self.latent_sampling_method == "random": + entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist) + return entry - for j in range(self.batch_size): - position = i * self.batch_size + j - if position % len(self.indexes) == 0: - self.shuffle() +class PersonalizedDataLoader(DataLoader): + def __init__(self, *args, **kwargs): + super(PersonalizedDataLoader, self).__init__(shuffle=True, drop_last=True, *args, **kwargs) + self.collate_fn = collate_wrapper + - index = self.indexes[position % len(self.indexes)] - entry = self.dataset[index] +class BatchLoader: + def __init__(self, data): + self.cond_text = [entry.cond_text for entry in data] + self.cond = [entry.cond for entry in data] + self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) - if entry.cond is None: - entry.cond_text = self.create_text(entry.filename_text) + def pin_memory(self): + self.latent_sample = self.latent_sample.pin_memory() + return self - res.append(entry) - - return res +def collate_wrapper(batch): + return BatchLoader(batch) \ No newline at end of file diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5e4d8688b..1d5e3a322 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -184,7 +184,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): if shared.opts.training_write_csv_every == 0: return - if (step + 1) % shared.opts.training_write_csv_every != 0: + if step % shared.opts.training_write_csv_every != 0: return write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True @@ -194,21 +194,23 @@ def write_loss(log_directory, filename, step, epoch_len, values): if write_csv_header: csv_writer.writeheader() - epoch = step // epoch_len - epoch_step = step % epoch_len + epoch = (step - 1) // epoch_len + epoch_step = (step - 1) % epoch_len csv_writer.writerow({ - "step": step + 1, + "step": step, "epoch": epoch, - "epoch_step": epoch_step + 1, + "epoch_step": epoch_step, **values, }) -def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): +def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" assert isinstance(batch_size, int), "Batch size must be integer" assert batch_size > 0, "Batch size must be positive" + assert isinstance(gradient_step, int), "Gradient accumulation step must be integer" + assert gradient_step > 0, "Gradient accumulation step must be positive" assert data_root, "Dataset directory is empty" assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" @@ -224,10 +226,10 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 - validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") shared.state.textinfo = "Initializing textual inversion training..." shared.state.job_count = steps @@ -255,161 +257,205 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc else: images_embeds_dir = None - cond_model = shared.sd_model.cond_stage_model - hijack = sd_hijack.model_hijack embedding = hijack.embedding_db.word_embeddings[embedding_name] checkpoint = sd_models.select_checkpoint() - ititial_step = embedding.step or 0 - if ititial_step >= steps: + initial_step = embedding.step or 0 + if initial_step >= steps: shared.state.textinfo = f"Model has already been trained beyond specified max steps" return embedding, filename + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - - # dataset loading may take a while, so input validations and early returns should be done before this + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." - with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + + pin_memory = shared.opts.pin_memory + + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + + latent_sampling_method = ds.latent_sampling_method + + dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=False) + if unload: shared.sd_model.first_stage_model.to(devices.cpu) embedding.vec.requires_grad = True optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) + scaler = torch.cuda.amp.GradScaler() - losses = torch.zeros((32,)) + batch_size = ds.batch_size + gradient_step = ds.gradient_step + # n steps = batch_size * gradient_step * n image processed + steps_per_epoch = len(ds) // batch_size // gradient_step + max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step + loss_step = 0 + _loss_step = 0 #internal + last_saved_file = "" last_saved_image = "" forced_filename = "" embedding_yet_to_be_embedded = False + + pbar = tqdm.tqdm(total=steps - initial_step) + try: + for i in range((steps-initial_step) * gradient_step): + if scheduler.finished: + break + if shared.state.interrupted: + break + for j, batch in enumerate(dl): + # works as a drop_last=True for gradient accumulation + if j == max_steps_per_epoch: + break + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break + if shared.state.interrupted: + break - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, entries in pbar: - embedding.step = i + ititial_step + with torch.autocast("cuda"): + # c = stack_conds(batch.cond).to(devices.device) + # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory) + # print(mask) + # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory) + x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) + c = shared.sd_model.cond_stage_model(batch.cond_text) + loss = shared.sd_model(x, c)[0] / gradient_step + del x + + _loss_step += loss.item() + scaler.scale(loss).backward() + + # go back until we reach gradient accumulation steps + if (j + 1) % gradient_step != 0: + continue + #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") + #scaler.unscale_(optimizer) + #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") + #torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=1.0) + #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") + scaler.step(optimizer) + scaler.update() + embedding.step += 1 + pbar.update() + optimizer.zero_grad(set_to_none=True) + loss_step = _loss_step + _loss_step = 0 - scheduler.apply(optimizer, embedding.step) - if scheduler.finished: - break + steps_done = embedding.step + 1 - if shared.state.interrupted: - break + epoch_num = embedding.step // steps_per_epoch + epoch_step = embedding.step % steps_per_epoch - with torch.autocast("cuda"): - c = cond_model([entry.cond_text for entry in entries]) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. + embedding_name_every = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') + #if shared.opts.save_optimizer_state: + #embedding.optimizer_state_dict = optimizer.state_dict() + save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) + embedding_yet_to_be_embedded = True - losses[embedding.step % losses.shape[0]] = loss.item() + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { + "loss": f"{loss_step:.7f}", + "learn_rate": scheduler.learn_rate + }) - optimizer.zero_grad() - loss.backward() - optimizer.step() + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) - steps_done = embedding.step + 1 + shared.sd_model.first_stage_model.to(devices.device) - epoch_num = embedding.step // len(ds) - epoch_step = embedding.step % len(ds) + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = batch.cond_text[0] + p.steps = 20 + p.width = training_width + p.height = training_height - if embedding_dir is not None and steps_done % save_embedding_every == 0: - # Before saving, change name to match current checkpoint. - embedding_name_every = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') - save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) - embedding_yet_to_be_embedded = True + preview_text = p.prompt - write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate - }) + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{embedding_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) + if unload: + shared.sd_model.first_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.device) + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - do_not_reload_embeddings=True, - ) + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_name = sd_samplers.samplers[preview_sampler_index].name - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - p.width = training_width - p.height = training_height + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') - preview_text = p.prompt + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) - processed = processing.process_images(p) - image = processed.images[0] + title = "<{}>".format(data.get('name', '???')) - if unload: - shared.sd_model.first_stage_model.to(devices.cpu) + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' - shared.state.current_image = image + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '{}v {}s'.format(vectorSize, steps_done) - if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) - last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False - info = PngImagePlugin.PngInfo() - data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embedding_to_b64(data)) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - title = "<{}>".format(data.get('name', '???')) + shared.state.job_no = embedding.step - try: - vectorSize = list(data['string_to_param'].values())[0].shape[0] - except Exception as e: - vectorSize = '?' - - checkpoint = sd_models.select_checkpoint() - footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}v {}s'.format(vectorSize, steps_done) - - captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) - captioned_image = insert_image_data_embed(captioned_image, data) - - captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) - embedding_yet_to_be_embedded = False - - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" - - shared.state.job_no = embedding.step - - shared.state.textinfo = f""" + shared.state.textinfo = f"""

-Loss: {losses.mean():.7f}
+Loss: {loss_step:.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entries[0].cond_text)}
+Last prompt: {html.escape(batch.cond_text[0])}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - - filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') - save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) - shared.sd_model.first_stage_model.to(devices.device) + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) + except Exception: + print(traceback.format_exc(), file=sys.stderr) + pass + finally: + pbar.leave = False + pbar.close() + shared.sd_model.first_stage_model.to(devices.device) return embedding, filename diff --git a/modules/ui.py b/modules/ui.py index a5953fce5..9d2a1cbfa 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1262,7 +1262,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): with gr.Row(): interrupt_preprocessing = gr.Button("Interrupt") - run_preprocess = gr.Button(value="Preprocess", variant='primary') + run_preprocess = gr.Button(value="Preprocess", variant='primary') process_split.change( fn=lambda show: gr_show(show), @@ -1289,6 +1289,7 @@ def create_ui(wrap_gradio_gpu_call): hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") batch_size = gr.Number(label='Batch size', value=1, precision=0) + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) @@ -1299,6 +1300,11 @@ def create_ui(wrap_gradio_gpu_call): save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) + with gr.Row(): + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0) + with gr.Row(): + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random']) with gr.Row(): interrupt_training = gr.Button(value="Interrupt") @@ -1387,11 +1393,15 @@ def create_ui(wrap_gradio_gpu_call): train_embedding_name, embedding_learn_rate, batch_size, + gradient_step, dataset_directory, log_directory, training_width, training_height, steps, + shuffle_tags, + tag_drop_out, + latent_sampling_method, create_image_every, save_embedding_every, template_file, @@ -1412,11 +1422,15 @@ def create_ui(wrap_gradio_gpu_call): train_hypernetwork_name, hypernetwork_learn_rate, batch_size, + gradient_step, dataset_directory, log_directory, training_width, training_height, steps, + shuffle_tags, + tag_drop_out, + latent_sampling_method, create_image_every, save_embedding_every, template_file, From a4a5735d0a80218e59f8a6e8401726f7209a6a8d Mon Sep 17 00:00:00 2001 From: flamelaw Date: Sun, 20 Nov 2022 12:38:18 +0900 Subject: [PATCH 22/72] remove unnecessary comment --- modules/textual_inversion/dataset.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index d594b49d2..1dd53b850 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -103,15 +103,6 @@ class PersonalizedBase(Dataset): if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): with torch.autocast("cuda"): entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) - # elif not include_cond: - # _, _, _, _, hijack_fixes, token_count = cond_model.process_text([entry.cond_text]) - # max_n = token_count // 75 - # index_list = [ [] for _ in range(max_n + 1) ] - # for n, (z, _) in hijack_fixes[0]: - # index_list[n].append(z) - # with torch.autocast("cuda"): - # entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) - # entry.emb_index = index_list self.dataset.append(entry) del torchdata From 2d22d72cdaaf2b78b2986b841d478c11ac855dd2 Mon Sep 17 00:00:00 2001 From: flamelaw Date: Sun, 20 Nov 2022 16:14:27 +0900 Subject: [PATCH 23/72] fix random sampling with pin_memory --- modules/textual_inversion/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 1dd53b850..110c0e09b 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -134,7 +134,7 @@ class PersonalizedBase(Dataset): if self.tag_drop_out != 0 or self.shuffle_tags: entry.cond_text = self.create_text(entry.filename_text) if self.latent_sampling_method == "random": - entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist) + entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) return entry class PersonalizedDataLoader(DataLoader): From 471189743a5f3dd7a5c822e63aad28f950abbd94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20B=C3=B6er?= Date: Sun, 20 Nov 2022 15:57:43 +0100 Subject: [PATCH 24/72] Move progress info to beginning of title because who has so few tabs open that they can see the end of a tab name? --- javascript/progressbar.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 671fde340..43d1d1ce0 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -23,7 +23,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ if(progressbar.innerText){ - let newtitle = 'Stable Diffusion - ' + progressbar.innerText + let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion'; if(document.title != newtitle){ document.title = newtitle; } From 637815632f9f362c9959e53139d37e88ea9ace6f Mon Sep 17 00:00:00 2001 From: Tim Patton <38817597+pattontim@users.noreply.github.com> Date: Sun, 20 Nov 2022 13:36:05 -0500 Subject: [PATCH 25/72] Generalize SD torch load/save to implement safetensor merging compat --- modules/extras.py | 15 +- modules/sd_models.py | 25 +- modules/ui.py | 3626 +++++++++++++++++++++--------------------- 3 files changed, 1840 insertions(+), 1826 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 71b93a068..820427de4 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -249,7 +249,7 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name): +def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, save_as_safetensors, custom_name): def weighted_sum(theta0, theta1, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) @@ -264,16 +264,16 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None) print(f"Loading {primary_model_info.filename}...") - primary_model = torch.load(primary_model_info.filename, map_location='cpu') + primary_model = sd_models.torch_load(primary_model_info.filename, primary_model_info, map_override='cpu') theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model) print(f"Loading {secondary_model_info.filename}...") - secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') + secondary_model = sd_models.torch_load(secondary_model_info.filename, primary_model_info, map_override='cpu') theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model) if teritary_model_info is not None: print(f"Loading {teritary_model_info.filename}...") - teritary_model = torch.load(teritary_model_info.filename, map_location='cpu') + teritary_model = sd_models.torch_load(teritary_model_info.filename, teritary_model_info, map_override='cpu') theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model) else: teritary_model = None @@ -314,12 +314,13 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' - filename = filename if custom_name == '' else (custom_name + '.ckpt') + output_exttype = '.safetensors' if save_as_safetensors else '.ckpt' + filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged' + output_exttype + filename = filename if custom_name == '' else (custom_name + output_exttype) output_modelname = os.path.join(ckpt_dir, filename) print(f"Saving to {output_modelname}...") - torch.save(primary_model, output_modelname) + sd_models.torch_save(primary_model, output_modelname) sd_models.list_models() diff --git a/modules/sd_models.py b/modules/sd_models.py index 4ccdf30b4..2f8c2c48e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,7 @@ import sys import gc from collections import namedtuple import torch -from safetensors.torch import load_file +from safetensors.torch import load_file, save_file import re from omegaconf import OmegaConf @@ -143,6 +143,22 @@ def transform_checkpoint_dict_key(k): return k +def torch_load(model_filename, model_info, map_override=None): + map_override=shared.weight_load_location if not map_override else map_override + if(checkpoint_types[model_info.exttype] == 'safetensors'): + # safely load weights + # TODO: safetensors supports zero copy fast load to gpu, see issue #684 + return load_file(model_filename, device=map_override) + else: + return torch.load(model_filename, map_location=map_override) + +def torch_save(model, output_filename): + basename, exttype = os.path.splitext(output_filename) + if(checkpoint_types[exttype] == 'safetensors'): + # [===== >] Reticulating brines... + save_file(model, output_filename, metadata={"format": "pt"}) + else: + torch.save(model, output_filename) def get_state_dict_from_checkpoint(pl_sd): if "state_dict" in pl_sd: @@ -175,12 +191,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - if(checkpoint_types[checkpoint_info.exttype] == 'safetensors'): - # safely load weights - # TODO: safetensors supports zero copy fast load to gpu, see issue #684 - pl_sd = load_file(checkpoint_file, device=shared.weight_load_location) - else: - pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + pl_sd = torch_load(checkpoint_file, checkpoint_info) if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") diff --git a/modules/ui.py b/modules/ui.py index a5953fce5..a2b06aae7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,1812 +1,1814 @@ -import html -import json -import math -import mimetypes -import os -import platform -import random -import subprocess as sp -import sys -import tempfile -import time -import traceback -from functools import partial, reduce - -import gradio as gr -import gradio.routes -import gradio.utils -import numpy as np -from PIL import Image, PngImagePlugin - - -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions -from modules.paths import script_path - -from modules.shared import opts, cmd_opts, restricted_opts - -if cmd_opts.deepdanbooru: - from modules.deepbooru import get_deepbooru_tags - -import modules.codeformer_model -import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.ldsr_model -import modules.scripts -import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui -from modules import prompt_parser -from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img -import modules.textual_inversion.ui -import modules.hypernetworks.ui -from modules.generation_parameters_copypaste import image_from_url_text - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -if not cmd_opts.share and not cmd_opts.listen: - # fix gradio phoning home - gradio.utils.version_check = lambda: None - gradio.utils.get_local_ip_address = lambda: '127.0.0.1' - -if cmd_opts.ngrok != None: - import modules.ngrok as ngrok - print('ngrok authtoken detected, trying to connect...') - ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region) - - -def gr_show(visible=True): - return {"visible": visible, "__type__": "update"} - - -sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" -sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None - -css_hide_progressbar = """ -.wrap .m-12 svg { display:none!important; } -.wrap .m-12::before { content:"Loading..." } -.wrap .z-20 svg { display:none!important; } -.wrap .z-20::before { content:"Loading..." } -.progress-bar { display:none!important; } -.meta-text { display:none!important; } -.meta-text-center { display:none!important; } -""" - -# Using constants for these since the variation selector isn't visible. -# Important that they exactly match script.js for tooltip to work. -random_symbol = '\U0001f3b2\ufe0f' # 🎲️ -reuse_symbol = '\u267b\ufe0f' # ♻️ -art_symbol = '\U0001f3a8' # 🎨 -paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -apply_style_symbol = '\U0001f4cb' # 📋 - - -def plaintext_to_html(text): - text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" - return text - -def send_gradio_gallery_to_image(x): - if len(x) == 0: - return None - return image_from_url_text(x[0]) - -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") - -def save_pil_to_file(pil_image, dir=None): - use_metadata = False - metadata = PngImagePlugin.PngInfo() - for key, value in pil_image.info.items(): - if isinstance(key, str) and isinstance(value, str): - metadata.add_text(key, value) - use_metadata = True - - file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) - pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) - return file_obj - - -# override save to file function so that it also writes PNG info -gr.processing_utils.save_pil_to_file = save_pil_to_file - - -def wrap_gradio_call(func, extra_outputs=None, add_stats=False): - def f(*args, extra_outputs_array=extra_outputs, **kwargs): - run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats - if run_memmon: - shared.mem_mon.monitor() - t = time.perf_counter() - - try: - res = list(func(*args, **kwargs)) - except Exception as e: - # When printing out our debug argument list, do not print out more than a MB of text - max_debug_str_len = 131072 # (1024*1024)/8 - - print("Error completing request", file=sys.stderr) - argStr = f"Arguments: {str(args)} {str(kwargs)}" - print(argStr[:max_debug_str_len], file=sys.stderr) - if len(argStr) > max_debug_str_len: - print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) - - print(traceback.format_exc(), file=sys.stderr) - - shared.state.job = "" - shared.state.job_count = 0 - - if extra_outputs_array is None: - extra_outputs_array = [None, ''] - - res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] - - shared.state.skipped = False - shared.state.interrupted = False - shared.state.job_count = 0 - - if not add_stats: - return tuple(res) - - elapsed = time.perf_counter() - t - elapsed_m = int(elapsed // 60) - elapsed_s = elapsed % 60 - elapsed_text = f"{elapsed_s:.2f}s" - if elapsed_m > 0: - elapsed_text = f"{elapsed_m}m "+elapsed_text - - if run_memmon: - mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} - active_peak = mem_stats['active_peak'] - reserved_peak = mem_stats['reserved_peak'] - sys_peak = mem_stats['system_peak'] - sys_total = mem_stats['total'] - sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) - - vram_html = f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" - else: - vram_html = '' - - # last item is always HTML - res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" - - return tuple(res) - - return f - - -def calc_time_left(progress, threshold, label, force_display): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and progress > 0.02) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display ) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}
""" - - image = gr_show(False) - preview_visibility = gr_show(False) - - if opts.show_progress_every_n_steps != 0: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr_show(True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr_show(False) - - return f"

{progressbar}

", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def roll_artist(prompt): - allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories]) - artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats]) - - return prompt + ", " + artist.name if prompt != '' else artist.name - - -def visit(x, func, path=""): - if hasattr(x, 'children'): - for c in x.children: - visit(c, func, path) - elif x.label is not None: - func(path + "/" + str(x.label), x) - - -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] - - -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] - - -def interrogate(image): - prompt = shared.interrogator.interrogate(image) - - return gr_show(True) if prompt is None else prompt - - -def interrogate_deepbooru(image): - prompt = get_deepbooru_tags(image) - return gr_show(True) if prompt is None else prompt - - -def create_seed_inputs(): - with gr.Row(): - with gr.Box(): - with gr.Row(elem_id='seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1) - seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id='random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed') - - with gr.Box(elem_id='subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False) - - # Components to show/hide based on the 'Extra' checkbox - seed_extras = [] - - with gr.Row(visible=False) as seed_extra_row_1: - seed_extras.append(seed_extra_row_1) - with gr.Box(): - with gr.Row(elem_id='subseed_row'): - subseed = gr.Number(label='Variation seed', value=-1) - subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id='random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01) - - with gr.Row(visible=False) as seed_extra_row_2: - seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0) - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0) - - random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) - random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) - - def change_visibility(show): - return {comp: gr_show(show) for comp in seed_extras} - - seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) - - return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - - -def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): - """ Connects a 'reuse (sub)seed' button's click event so that it copies last used - (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength - was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" - def copy_seed(gen_info_string: str, index): - res = -1 - - try: - gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError as e: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) - - return [res, gr_show(False)] - - reuse_seed.click( - fn=copy_seed, - _js="(x, y) => [x, selected_gallery_index()]", - show_progress=False, - inputs=[generation_info, dummy_component], - outputs=[seed, dummy_component] - ) - - -def update_token_counter(text, steps): - try: - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) - prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) - - except Exception: - # a parsing error can happen here during typing, and we don't want to bother the user with - # messages related to it in console - prompt_schedules = [[[steps, text]]] - - flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) - prompts = [prompt_text for step, prompt_text in flat_prompts] - tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) - style_class = ' class="red"' if (token_count > max_length) else "" - return f"{token_count}/{max_length}" - - -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - - with gr.Row(elem_id="toprow"): - with gr.Column(scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, - placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, - placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Column(scale=1, elem_id="roll_col"): - roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_id="interrogate_col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - - if cmd_opts.deepdanbooru: - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1): - with gr.Row(): - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - prompt_style.save_to_config = True - - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - prompt_style2.save_to_config = True - - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - -def apply_setting(key, value): - if value is None: - return gr.update() - - if shared.cmd_opts.freeze_settings: - return gr.update() - - # dont allow model to be swapped when model hash exists in prompt - if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: - return gr.update() - - if key == "sd_model_checkpoint": - ckpt_info = sd_models.get_closet_checkpoint_match(value) - - if ckpt_info is not None: - value = ckpt_info.title - else: - return gr.update() - - comp_args = opts.data_labels[key].component_args - if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: - return - - valtype = type(opts.data_labels[key].default) - oldval = opts.data[key] - opts.data[key] = valtype(value) if valtype != type(None) else value - if oldval != value and opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() - - opts.save(shared.config_filename) - return value - - -def update_generation_info(args): - generation_info, html_info, img_index = args - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info - - -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - -def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel'): - with gr.Group(): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(): - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' - open_folder_button = gr.Button(folder_symbol, elem_id=button_id) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) - - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) - - with gr.Group(): - html_info = gr.HTML() - generation_info = gr.Textbox(visible=False) - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - do_make_zip, - html_info, - ], - outputs=[ - download_files, - html_info, - html_info, - html_info, - ] - ) - else: - html_info_x = gr.HTML() - html_info = gr.HTML() - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info - - -def create_ui(wrap_gradio_gpu_call): - import modules.img2img - import modules.txt2img - - reload_javascript() - - parameters_copypaste.reset() - - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) - - with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) - dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") - - with gr.Group(): - width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - - with gr.Row(): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) - tiling = gr.Checkbox(label='Tiling', value=False) - enable_hr = gr.Checkbox(label='Highres. fix', value=False) - - with gr.Row(visible=False) as hr_options: - firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0) - firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0) - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) - - with gr.Row(equal_height=True): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) - - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) - - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - - with gr.Group(): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - - txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - txt2img_args = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), - _js="submit", - inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - denoising_strength, - firstphase_width, - firstphase_height, - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info - ], - show_progress=False, - ) - - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) - - txt_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - txt_prompt_img - ], - outputs=[ - txt2img_prompt, - txt_prompt_img - ] - ) - - enable_hr.change( - fn=lambda x: gr_show(x), - inputs=[enable_hr], - outputs=[hr_options], - ) - - roll.click( - fn=roll_artist, - _js="update_txt2img_tokens", - inputs=[ - txt2img_prompt, - ], - outputs=[ - txt2img_prompt, - ] - ) - - txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (firstphase_width, "First pass size-1"), - (firstphase_height, "First pass size-2"), - *modules.scripts.scripts_txt2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) - - txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, - steps, - sampler_index, - cfg_scale, - seed, - width, - height, - ] - - token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) - - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) - - with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True) - - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img'): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480) - - with gr.TabItem('Inpaint', id='inpaint'): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) - - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") - - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) - - with gr.Row(): - mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index") - - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index") - - with gr.Row(): - inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) - inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) - - with gr.TabItem('Batch img2img', id='batch'): - hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) - - with gr.Row(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize") - - steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) - sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") - - with gr.Group(): - width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height") - - with gr.Row(): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) - tiling = gr.Checkbox(label='Tiling', value=False) - - with gr.Row(): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) - - with gr.Group(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75) - - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - - with gr.Group(): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() - - img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ] - ) - - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - - img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img), - _js="submit_img2img", - inputs=[ - dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, - init_img, - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - mask_mode, - steps, - sampler_index, - mask_blur, - inpainting_fill, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - denoising_strength, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - ] + custom_inputs, - outputs=[ - img2img_gallery, - generation_info, - html_info - ], - show_progress=False, - ) - - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) - - img2img_interrogate.click( - fn=interrogate, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - if cmd_opts.deepdanbooru: - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - - roll.click( - fn=roll_artist, - _js="update_img2img_tokens", - inputs=[ - img2img_prompt, - ], - outputs=[ - img2img_prompt, - ] - ) - - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], - ) - - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - - img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - *modules.scripts.scripts_img2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) - - modules.scripts.scripts_current = None - - with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image'): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil") - - with gr.TabItem('Batch Process'): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") - - with gr.TabItem('Batch from Directory'): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.") - show_extras_results = gr.Checkbox(label='Show result images', value=True) - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by'): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4) - with gr.TabItem('Scale to'): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan) - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer) - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer) - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) - - result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, - html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=modules.extras.clear_cache, - inputs=[], outputs=[] - ) - - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") - - with gr.Column(variant='panel'): - html = gr.HTML() - generation_info = gr.Textbox(visible=False) - html2 = gr.HTML() - with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) - - image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), - inputs=[image], - outputs=[html, generation_info, html2], - ) - - with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - - with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") - custom_name = gr.Textbox(label="Custom Name (Optional)") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") - save_as_half = gr.Checkbox(value=False, label="Save as float16") - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): - gr.HTML(value="

See wiki for detailed explanation.

") - - with gr.Row().style(equal_height=False): - with gr.Tabs(elem_id="train_tabs"): - - with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name") - initialization_text = gr.Textbox(label="Initialization text", value="*") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary') - - with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') - - with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory') - process_dst = gr.Textbox(label='Destination directory') - process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) - - with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies') - process_split = gr.Checkbox(label='Split oversized images') - process_focal_crop = gr.Checkbox(label='Auto focal point crop') - process_caption = gr.Checkbox(label='Use BLIP for caption') - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_debug = gr.Checkbox(label='Create debug image') - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt") - run_preprocess = gr.Button(value="Preprocess", variant='primary') - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") - with gr.Row(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - with gr.Row(): - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - with gr.Row(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") - - batch_size = gr.Number(label='Batch size', value=1, precision=0) - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) - training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - steps = gr.Number(label='Max steps', value=100000, precision=0) - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) - - with gr.Row(): - interrupt_training = gr.Button(value="Interrupt") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') - train_embedding = gr.Button(value="Train Embedding", variant='primary') - - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) - - script_callbacks.ui_train_tabs_callback(params) - - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) - ti_progress = gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) - - create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_embedding_name, - embedding_learn_rate, - batch_size, - dataset_directory, - log_directory, - training_width, - training_height, - steps, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_hypernetwork_name, - hypernetwork_learn_rate, - batch_size, - dataset_directory, - log_directory, - training_width, - training_height, - steps, - create_image_every, - save_embedding_every, - template_file, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - interrupt_training.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {str(t)} for key {key}') - - elem_id = "setting_"+key - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - with gr.Row(variant="compact"): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - - components = [] - component_dict = {} - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' - - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return gr.update(value=value), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - settings_submit = gr.Button(value="Apply settings", variant='primary') - result = gr.HTML() - - settings_cols = 3 - items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) - - quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') - - quicksettings_list = [] - - cols_displayed = 0 - items_displayed = 0 - previous_section = None - column = None - with gr.Row(elem_id="settings").style(equal_height=False): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): - if column is not None: - column.__exit__() - - column = gr.Column(variant='panel') - column.__enter__() - - items_displayed = 0 - cols_displayed += 1 - - previous_section = item.section - - elem_id, text = item.section - gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - items_displayed += 1 - - with gr.Row(): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - - with gr.Row(): - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - - restart_gradio.click( - fn=request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) - - if column is not None: - column.__exit__() - - interfaces = [ - (txt2img_interface, "txt2img", "txt2img"), - (img2img_interface, "img2img", "img2img"), - (extras_interface, "Extras", "extras"), - (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), - (train_interface, "Train", "ti"), - ] - - css = "" - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - with open(cssfile, "r", encoding="utf8") as file: - css += file.read() + "\n" - - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - css += file.read() + "\n" - - if not cmd_opts.no_progressbar_hiding: - css += css_hide_progressbar - - interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] - - extensions_interface = ui_extensions.create_ui() - interfaces += [(extensions_interface, "Extensions", "extensions")] - - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings"): - for i, k, item in quicksettings_list: - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component - - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() - - with gr.Tabs(elem_id="tabs") as tabs: - for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): - interface.render() - - if os.path.exists(os.path.join(script_path, "notification.mp3")): - audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for i, k, item in quicksettings_list: - component = component_dict[k] - - component.change( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [getattr(opts, key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - ) - - def modelmerger(*args): - try: - results = modules.extras.run_modelmerger(*args) - except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] - return results - - modelmerger_merge.click( - fn=modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - interp_method, - interp_amount, - save_as_half, - custom_name, - ], - outputs=[ - submit_result, - primary_model_name, - secondary_model_name, - tertiary_model_name, - component_dict['sd_model_checkpoint'], - ] - ) - - ui_config_file = cmd_opts.ui_config_file - ui_settings = {} - settings_count = len(ui_settings) - error_loading = False - - try: - if os.path.exists(ui_config_file): - with open(ui_config_file, "r", encoding="utf8") as file: - ui_settings = json.load(file) - except Exception: - error_loading = True - print("Error loading settings:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - def loadsave(path, x): - def apply_field(obj, field, condition=None, init_field=None): - key = path + "/" + field - - if getattr(obj, 'custom_script_source', None) is not None: - key = 'customscript/' + obj.custom_script_source + '/' + key - - if getattr(obj, 'do_not_save_to_config', False): - return - - saved_value = ui_settings.get(key, None) - if saved_value is None: - ui_settings[key] = getattr(obj, field) - elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') - else: - setattr(obj, field, saved_value) - if init_field is not None: - init_field(saved_value) - - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: - apply_field(x, 'visible') - - if type(x) == gr.Slider: - apply_field(x, 'value') - apply_field(x, 'minimum') - apply_field(x, 'maximum') - apply_field(x, 'step') - - if type(x) == gr.Radio: - apply_field(x, 'value', lambda val: val in x.choices) - - if type(x) == gr.Checkbox: - apply_field(x, 'value') - - if type(x) == gr.Textbox: - apply_field(x, 'value') - - if type(x) == gr.Number: - apply_field(x, 'value') - - # Since there are many dropdowns that shouldn't be saved, - # we only mark dropdowns that should be saved. - if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False): - apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - apply_field(x, 'visible') - - visit(txt2img_interface, loadsave, "txt2img") - visit(img2img_interface, loadsave, "img2img") - visit(extras_interface, loadsave, "extras") - visit(modelmerger_interface, loadsave, "modelmerger") - - if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): - with open(ui_config_file, "w", encoding="utf8") as file: - json.dump(ui_settings, file, indent=4) - - return demo - - -def reload_javascript(): - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' - - scripts_list = modules.scripts.list_scripts("javascript", ".js") - - for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" - - if cmd_opts.theme is not None: - javascript += f"\n\n" - - javascript += f"\n" - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse +import html +import json +import math +import mimetypes +import os +import platform +import random +import subprocess as sp +import sys +import tempfile +import time +import traceback +from functools import partial, reduce + +import gradio as gr +import gradio.routes +import gradio.utils +import numpy as np +from PIL import Image, PngImagePlugin + + +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions +from modules.paths import script_path + +from modules.shared import opts, cmd_opts, restricted_opts + +if cmd_opts.deepdanbooru: + from modules.deepbooru import get_deepbooru_tags + +import modules.codeformer_model +import modules.generation_parameters_copypaste as parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.ldsr_model +import modules.scripts +import modules.shared as shared +import modules.styles +import modules.textual_inversion.ui +from modules import prompt_parser +from modules.images import save_image +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img +import modules.textual_inversion.ui +import modules.hypernetworks.ui +from modules.generation_parameters_copypaste import image_from_url_text + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +if not cmd_opts.share and not cmd_opts.listen: + # fix gradio phoning home + gradio.utils.version_check = lambda: None + gradio.utils.get_local_ip_address = lambda: '127.0.0.1' + +if cmd_opts.ngrok != None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region) + + +def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + +sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None + +css_hide_progressbar = """ +.wrap .m-12 svg { display:none!important; } +.wrap .m-12::before { content:"Loading..." } +.wrap .z-20 svg { display:none!important; } +.wrap .z-20::before { content:"Loading..." } +.progress-bar { display:none!important; } +.meta-text { display:none!important; } +.meta-text-center { display:none!important; } +""" + +# Using constants for these since the variation selector isn't visible. +# Important that they exactly match script.js for tooltip to work. +random_symbol = '\U0001f3b2\ufe0f' # 🎲️ +reuse_symbol = '\u267b\ufe0f' # ♻️ +art_symbol = '\U0001f3a8' # 🎨 +paste_symbol = '\u2199\ufe0f' # ↙ +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 + + +def plaintext_to_html(text): + text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" + return text + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + return image_from_url_text(x[0]) + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = opts.outdir_save + save_to_dirs = opts.use_save_to_dirs_for_ui + extension: str = opts.samples_format + start_index = 0 + + if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(opts.outdir_save, exist_ok=True) + + with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") + +def save_pil_to_file(pil_image, dir=None): + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in pil_image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + + file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) + pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) + return file_obj + + +# override save to file function so that it also writes PNG info +gr.processing_utils.save_pil_to_file = save_pil_to_file + + +def wrap_gradio_call(func, extra_outputs=None, add_stats=False): + def f(*args, extra_outputs_array=extra_outputs, **kwargs): + run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats + if run_memmon: + shared.mem_mon.monitor() + t = time.perf_counter() + + try: + res = list(func(*args, **kwargs)) + except Exception as e: + # When printing out our debug argument list, do not print out more than a MB of text + max_debug_str_len = 131072 # (1024*1024)/8 + + print("Error completing request", file=sys.stderr) + argStr = f"Arguments: {str(args)} {str(kwargs)}" + print(argStr[:max_debug_str_len], file=sys.stderr) + if len(argStr) > max_debug_str_len: + print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) + + print(traceback.format_exc(), file=sys.stderr) + + shared.state.job = "" + shared.state.job_count = 0 + + if extra_outputs_array is None: + extra_outputs_array = [None, ''] + + res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] + + shared.state.skipped = False + shared.state.interrupted = False + shared.state.job_count = 0 + + if not add_stats: + return tuple(res) + + elapsed = time.perf_counter() - t + elapsed_m = int(elapsed // 60) + elapsed_s = elapsed % 60 + elapsed_text = f"{elapsed_s:.2f}s" + if elapsed_m > 0: + elapsed_text = f"{elapsed_m}m "+elapsed_text + + if run_memmon: + mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} + active_peak = mem_stats['active_peak'] + reserved_peak = mem_stats['reserved_peak'] + sys_peak = mem_stats['system_peak'] + sys_total = mem_stats['total'] + sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) + + vram_html = f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" + else: + vram_html = '' + + # last item is always HTML + res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" + + return tuple(res) + + return f + + +def calc_time_left(progress, threshold, label, force_display): + if progress == 0: + return "" + else: + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + eta_relative = eta-time_since_start + if (eta_relative > threshold and progress > 0.02) or force_display: + if eta_relative > 3600: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + elif eta_relative > 60: + return label + time.strftime('%M:%S', time.gmtime(eta_relative)) + else: + return label + time.strftime('%Ss', time.gmtime(eta_relative)) + else: + return "" + + +def check_progress_call(id_part): + if shared.state.job_count == 0: + return "", gr_show(False), gr_show(False), gr_show(False) + + progress = 0 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display ) + if time_left != "": + shared.state.time_left_force_display = True + + progress = min(progress, 1) + + progressbar = "" + if opts.show_progressbar: + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}
""" + + image = gr_show(False) + preview_visibility = gr_show(False) + + if opts.show_progress_every_n_steps != 0: + shared.state.set_current_image() + image = shared.state.current_image + + if image is None: + image = gr.update(value=None) + else: + preview_visibility = gr_show(True) + + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"

{progressbar}

", preview_visibility, image, textinfo_result + + +def check_progress_call_initial(id_part): + shared.state.job_count = -1 + shared.state.current_latent = None + shared.state.current_image = None + shared.state.textinfo = None + shared.state.time_start = time.time() + shared.state.time_left_force_display = False + + return check_progress_call(id_part) + + +def roll_artist(prompt): + allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories]) + artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats]) + + return prompt + ", " + artist.name if prompt != '' else artist.name + + +def visit(x, func, path=""): + if hasattr(x, 'children'): + for c in x.children: + visit(c, func, path) + elif x.label is not None: + func(path + "/" + str(x.label), x) + + +def add_style(name: str, prompt: str, negative_prompt: str): + if name is None: + return [gr_show() for x in range(4)] + + style = modules.styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we + # reserialize all styles every time we save them + shared.prompt_styles.save_styles(shared.styles_filename) + + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] + + +def apply_styles(prompt, prompt_neg, style1_name, style2_name): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] + + +def interrogate(image): + prompt = shared.interrogator.interrogate(image) + + return gr_show(True) if prompt is None else prompt + + +def interrogate_deepbooru(image): + prompt = get_deepbooru_tags(image) + return gr_show(True) if prompt is None else prompt + + +def create_seed_inputs(): + with gr.Row(): + with gr.Box(): + with gr.Row(elem_id='seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1) + seed.style(container=False) + random_seed = gr.Button(random_symbol, elem_id='random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed') + + with gr.Box(elem_id='subseed_show_box'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False) + + # Components to show/hide based on the 'Extra' checkbox + seed_extras = [] + + with gr.Row(visible=False) as seed_extra_row_1: + seed_extras.append(seed_extra_row_1) + with gr.Box(): + with gr.Row(elem_id='subseed_row'): + subseed = gr.Number(label='Variation seed', value=-1) + subseed.style(container=False) + random_subseed = gr.Button(random_symbol, elem_id='random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed') + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01) + + with gr.Row(visible=False) as seed_extra_row_2: + seed_extras.append(seed_extra_row_2) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0) + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0) + + random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) + random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) + + def change_visibility(show): + return {comp: gr_show(show) for comp in seed_extras} + + seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) + + return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox + + +def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): + """ Connects a 'reuse (sub)seed' button's click event so that it copies last used + (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength + was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" + def copy_seed(gen_info_string: str, index): + res = -1 + + try: + gen_info = json.loads(gen_info_string) + index -= gen_info.get('index_of_first_image', 0) + + if is_subseed and gen_info.get('subseed_strength', 0) > 0: + all_subseeds = gen_info.get('all_subseeds', [-1]) + res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] + else: + all_seeds = gen_info.get('all_seeds', [-1]) + res = all_seeds[index if 0 <= index < len(all_seeds) else 0] + + except json.decoder.JSONDecodeError as e: + if gen_info_string != '': + print("Error parsing JSON generation info:", file=sys.stderr) + print(gen_info_string, file=sys.stderr) + + return [res, gr_show(False)] + + reuse_seed.click( + fn=copy_seed, + _js="(x, y) => [x, selected_gallery_index()]", + show_progress=False, + inputs=[generation_info, dummy_component], + outputs=[seed, dummy_component] + ) + + +def update_token_counter(text, steps): + try: + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + + except Exception: + # a parsing error can happen here during typing, and we don't want to bother the user with + # messages related to it in console + prompt_schedules = [[[steps, text]]] + + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step, prompt_text in flat_prompts] + tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) + style_class = ' class="red"' if (token_count > max_length) else "" + return f"{token_count}/{max_length}" + + +def create_toprow(is_img2img): + id_part = "img2img" if is_img2img else "txt2img" + + with gr.Row(elem_id="toprow"): + with gr.Column(scale=6): + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, + placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Column(scale=1, elem_id="roll_col"): + roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) + paste = gr.Button(value=paste_symbol, elem_id="paste") + save_style = gr.Button(value=save_style_symbol, elem_id="style_create") + prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") + + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + + button_interrogate = None + button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_id="interrogate_col"): + button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + + if cmd_opts.deepdanbooru: + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + + with gr.Column(scale=1): + with gr.Row(): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") + interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") + submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + with gr.Row(): + with gr.Column(scale=1, elem_id="style_pos_col"): + prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + prompt_style.save_to_config = True + + with gr.Column(scale=1, elem_id="style_neg_col"): + prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + prompt_style2.save_to_config = True + + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + + +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) + check_progress.click( + fn=lambda: check_progress_call(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) + check_progress_initial.click( + fn=lambda: check_progress_call_initial(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + +def apply_setting(key, value): + if value is None: + return gr.update() + + if shared.cmd_opts.freeze_settings: + return gr.update() + + # dont allow model to be swapped when model hash exists in prompt + if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: + return gr.update() + + if key == "sd_model_checkpoint": + ckpt_info = sd_models.get_closet_checkpoint_match(value) + + if ckpt_info is not None: + value = ckpt_info.title + else: + return gr.update() + + comp_args = opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data[key] + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + return value + + +def update_generation_info(args): + generation_info, html_info, img_index = args + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info + return plaintext_to_html(generation_info["infotexts"][img_index]) + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_output_panel(tabname, outdir): + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel'): + with gr.Group(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(): + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' + open_folder_button = gr.Button(folder_symbol, elem_id=button_id) + + open_folder_button.click( + fn=lambda: open_folder(opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) + + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) + + with gr.Group(): + html_info = gr.HTML() + generation_info = gr.Textbox(visible=False) + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="(x, y) => [x, y, selected_gallery_index()]", + inputs=[generation_info, html_info], + outputs=[html_info], + preprocess=False + ) + + save.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + do_make_zip, + html_info, + ], + outputs=[ + download_files, + html_info, + html_info, + html_info, + ] + ) + else: + html_info_x = gr.HTML() + html_info = gr.HTML() + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info + + +def create_ui(wrap_gradio_gpu_call): + import modules.img2img + import modules.txt2img + + reload_javascript() + + parameters_copypaste.reset() + + modules.scripts.scripts_current = modules.scripts.scripts_txt2img + modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + + with gr.Blocks(analytics_enabled=False) as txt2img_interface: + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Row(elem_id='txt2img_progress_row'): + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="txt2img_progressbar") + txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) + setup_progressbar(progressbar, txt2img_preview, 'txt2img') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") + + with gr.Group(): + width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + + with gr.Row(): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) + tiling = gr.Checkbox(label='Tiling', value=False) + enable_hr = gr.Checkbox(label='Highres. fix', value=False) + + with gr.Row(visible=False) as hr_options: + firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0) + firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0) + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) + + with gr.Row(equal_height=True): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) + + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) + + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() + + with gr.Group(): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + + txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) + parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + txt2img_args = dict( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), + _js="submit", + inputs=[ + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + denoising_strength, + firstphase_width, + firstphase_height, + ] + custom_inputs, + + outputs=[ + txt2img_gallery, + generation_info, + html_info + ], + show_progress=False, + ) + + txt2img_prompt.submit(**txt2img_args) + submit.click(**txt2img_args) + + txt_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ] + ) + + enable_hr.change( + fn=lambda x: gr_show(x), + inputs=[enable_hr], + outputs=[hr_options], + ) + + roll.click( + fn=roll_artist, + _js="update_txt2img_tokens", + inputs=[ + txt2img_prompt, + ], + outputs=[ + txt2img_prompt, + ] + ) + + txt2img_paste_fields = [ + (txt2img_prompt, "Prompt"), + (txt2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (enable_hr, lambda d: "Denoising strength" in d), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (firstphase_width, "First pass size-1"), + (firstphase_height, "First pass size-2"), + *modules.scripts.scripts_txt2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) + + txt2img_preview_params = [ + txt2img_prompt, + txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + + token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) + + modules.scripts.scripts_current = modules.scripts.scripts_img2img + modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + + with gr.Blocks(analytics_enabled=False) as img2img_interface: + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True) + + with gr.Row(elem_id='img2img_progress_row'): + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="img2img_progressbar") + img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) + setup_progressbar(progressbar, img2img_preview, 'img2img') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + + with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: + with gr.TabItem('img2img', id='img2img'): + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480) + + with gr.TabItem('Inpaint', id='inpaint'): + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) + + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) + + with gr.Row(): + mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") + inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index") + + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index") + + with gr.Row(): + inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) + inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) + + with gr.TabItem('Batch img2img', id='batch'): + hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) + + with gr.Row(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize") + + steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) + sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") + + with gr.Group(): + width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height") + + with gr.Row(): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) + tiling = gr.Checkbox(label='Tiling', value=False) + + with gr.Row(): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) + + with gr.Group(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75) + + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() + + with gr.Group(): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() + + img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) + parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + img2img_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + img2img_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ] + ) + + mask_mode.change( + lambda mode, img: { + init_img_with_mask: gr_show(mode == 0), + init_img_inpaint: gr_show(mode == 1), + init_mask_inpaint: gr_show(mode == 1), + }, + inputs=[mask_mode, init_img_with_mask], + outputs=[ + init_img_with_mask, + init_img_inpaint, + init_mask_inpaint, + ], + ) + + img2img_args = dict( + fn=wrap_gradio_gpu_call(modules.img2img.img2img), + _js="submit_img2img", + inputs=[ + dummy_component, + img2img_prompt, + img2img_negative_prompt, + img2img_prompt_style, + img2img_prompt_style2, + init_img, + init_img_with_mask, + init_img_inpaint, + init_mask_inpaint, + mask_mode, + steps, + sampler_index, + mask_blur, + inpainting_fill, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + denoising_strength, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + resize_mode, + inpaint_full_res, + inpaint_full_res_padding, + inpainting_mask_invert, + img2img_batch_input_dir, + img2img_batch_output_dir, + ] + custom_inputs, + outputs=[ + img2img_gallery, + generation_info, + html_info + ], + show_progress=False, + ) + + img2img_prompt.submit(**img2img_args) + submit.click(**img2img_args) + + img2img_interrogate.click( + fn=interrogate, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + if cmd_opts.deepdanbooru: + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + + roll.click( + fn=roll_artist, + _js="update_img2img_tokens", + inputs=[ + img2img_prompt, + ], + outputs=[ + img2img_prompt, + ] + ) + + prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] + style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] + + for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): + button.click( + fn=add_style, + _js="ask_for_style_name", + # Have to pass empty dummy component here, because the JavaScript and Python function have to accept + # the same number of parameters, but we only know the style-name after the JavaScript prompt + inputs=[dummy_component, prompt, negative_prompt], + outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], + ) + + for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + button.click( + fn=apply_styles, + _js=js_func, + inputs=[prompt, negative_prompt, style1, style2], + outputs=[prompt, negative_prompt, style1, style2], + ) + + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + + img2img_paste_fields = [ + (img2img_prompt, "Prompt"), + (img2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + *modules.scripts.scripts_img2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + + modules.scripts.scripts_current = None + + with gr.Blocks(analytics_enabled=False) as extras_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image'): + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil") + + with gr.TabItem('Batch Process'): + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") + + with gr.TabItem('Batch from Directory'): + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.") + show_extras_results = gr.Checkbox(label='Show result images', value=True) + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by'): + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4) + with gr.TabItem('Scale to'): + with gr.Group(): + with gr.Row(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) + + with gr.Group(): + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + + with gr.Group(): + extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) + + with gr.Group(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan) + + with gr.Group(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer) + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer) + + with gr.Group(): + upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) + + result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) + + submit.click( + fn=wrap_gradio_gpu_call(modules.extras.run_extras), + _js="get_extras_tab_index", + inputs=[ + dummy_component, + dummy_component, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + gfpgan_visibility, + codeformer_visibility, + codeformer_weight, + upscaling_resize, + upscaling_resize_w, + upscaling_resize_h, + upscaling_crop, + extras_upscaler_1, + extras_upscaler_2, + extras_upscaler_2_visibility, + upscale_before_face_fix, + ], + outputs=[ + result_images, + html_info_x, + html_info, + ] + ) + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=modules.extras.clear_cache, + inputs=[], outputs=[] + ) + + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") + + with gr.Column(variant='panel'): + html = gr.HTML() + generation_info = gr.Textbox(visible=False) + html2 = gr.HTML() + with gr.Row(): + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) + parameters_copypaste.bind_buttons(buttons, image, generation_info) + + image.change( + fn=wrap_gradio_call(modules.extras.run_pnginfo), + inputs=[image], + outputs=[html, generation_info, html2], + ) + + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") + + with gr.Row(): + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + custom_name = gr.Textbox(label="Custom Name (Optional)") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") + save_as_half = gr.Checkbox(value=False, label="Save as float16") + save_as_safetensors = gr.Checkbox(value=False, label="Save as safetensors format") + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + + with gr.Column(variant='panel'): + submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + with gr.Blocks(analytics_enabled=False) as train_interface: + with gr.Row().style(equal_height=False): + gr.HTML(value="

See wiki for detailed explanation.

") + + with gr.Row().style(equal_height=False): + with gr.Tabs(elem_id="train_tabs"): + + with gr.Tab(label="Create embedding"): + new_embedding_name = gr.Textbox(label="Name") + initialization_text = gr.Textbox(label="Initialization text", value="*") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create embedding", variant='primary') + + with gr.Tab(label="Create hypernetwork"): + new_hypernetwork_name = gr.Textbox(label="Name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') + + with gr.Tab(label="Preprocess images"): + process_src = gr.Textbox(label='Source directory') + process_dst = gr.Textbox(label='Destination directory') + process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) + + with gr.Row(): + process_flip = gr.Checkbox(label='Create flipped copies') + process_split = gr.Checkbox(label='Split oversized images') + process_focal_crop = gr.Checkbox(label='Auto focal point crop') + process_caption = gr.Checkbox(label='Use BLIP for caption') + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) + + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + + with gr.Row(visible=False) as process_focal_crop_row: + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + process_focal_crop_debug = gr.Checkbox(label='Create debug image') + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt") + run_preprocess = gr.Button(value="Preprocess", variant='primary') + + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + + process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[process_focal_crop], + outputs=[process_focal_crop_row], + ) + + with gr.Tab(label="Train"): + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") + with gr.Row(): + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + with gr.Row(): + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") + with gr.Row(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") + + batch_size = gr.Number(label='Batch size', value=1, precision=0) + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") + template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + steps = gr.Number(label='Max steps', value=100000, precision=0) + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) + + with gr.Row(): + interrupt_training = gr.Button(value="Interrupt") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') + train_embedding = gr.Button(value="Train Embedding", variant='primary') + + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + create_hypernetwork.click( + fn=modules.hypernetworks.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + overwrite_old_hypernetwork, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_width, + process_height, + preprocess_txt_action, + process_flip, + process_split, + process_caption, + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + embedding_learn_rate, + batch_size, + dataset_directory, + log_directory, + training_width, + training_height, + steps, + create_image_every, + save_embedding_every, + template_file, + save_image_with_stored_embedding, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + hypernetwork_learn_rate, + batch_size, + dataset_directory, + log_directory, + training_width, + training_height, + steps, + create_image_every, + save_embedding_every, + template_file, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {str(t)} for key {key}') + + elem_id = "setting_"+key + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + with gr.Row(variant="compact"): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + components = [] + component_dict = {} + + script_callbacks.ui_settings_callback() + opts.reorder() + + def run_settings(*args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' + + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return gr.update(value=value), opts.dumpjson() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + settings_submit = gr.Button(value="Apply settings", variant='primary') + result = gr.HTML() + + settings_cols = 3 + items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) + + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') + + quicksettings_list = [] + + cols_displayed = 0 + items_displayed = 0 + previous_section = None + column = None + with gr.Row(elem_id="settings").style(equal_height=False): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): + if column is not None: + column.__exit__() + + column = gr.Column(variant='panel') + column.__enter__() + + items_displayed = 0 + cols_displayed += 1 + + previous_section = item.section + + elem_id, text = item.section + gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) + + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + items_displayed += 1 + + with gr.Row(): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + + with gr.Row(): + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + def request_restart(): + shared.state.interrupt() + shared.state.need_restart = True + + restart_gradio.click( + fn=request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + if column is not None: + column.__exit__() + + interfaces = [ + (txt2img_interface, "txt2img", "txt2img"), + (img2img_interface, "img2img", "img2img"), + (extras_interface, "Extras", "extras"), + (pnginfo_interface, "PNG Info", "pnginfo"), + (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (train_interface, "Train", "ti"), + ] + + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + with open(cssfile, "r", encoding="utf8") as file: + css += file.read() + "\n" + + if os.path.exists(os.path.join(script_path, "user.css")): + with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + css += file.read() + "\n" + + if not cmd_opts.no_progressbar_hiding: + css += css_hide_progressbar + + interfaces += script_callbacks.ui_tabs_callback() + interfaces += [(settings_interface, "Settings", "settings")] + + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Row(elem_id="quicksettings"): + for i, k, item in quicksettings_list: + component = create_setting_component(k, is_quicksettings=True) + component_dict[k] = component + + parameters_copypaste.integrate_settings_paste_fields(component_dict) + parameters_copypaste.run_bind() + + with gr.Tabs(elem_id="tabs") as tabs: + for interface, label, ifid in interfaces: + with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): + interface.render() + + if os.path.exists(os.path.join(script_path, "notification.mp3")): + audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) + + text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + settings_submit.click( + fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), + inputs=components, + outputs=[text_settings, result], + ) + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + def get_settings_values(): + return [getattr(opts, key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + ) + + def modelmerger(*args): + try: + results = modules.extras.run_modelmerger(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + modules.sd_models.list_models() # to remove the potentially missing models from the list + return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] + return results + + modelmerger_merge.click( + fn=modelmerger, + inputs=[ + primary_model_name, + secondary_model_name, + tertiary_model_name, + interp_method, + interp_amount, + save_as_half, + save_as_safetensors, + custom_name, + ], + outputs=[ + submit_result, + primary_model_name, + secondary_model_name, + tertiary_model_name, + component_dict['sd_model_checkpoint'], + ] + ) + + ui_config_file = cmd_opts.ui_config_file + ui_settings = {} + settings_count = len(ui_settings) + error_loading = False + + try: + if os.path.exists(ui_config_file): + with open(ui_config_file, "r", encoding="utf8") as file: + ui_settings = json.load(file) + except Exception: + error_loading = True + print("Error loading settings:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def loadsave(path, x): + def apply_field(obj, field, condition=None, init_field=None): + key = path + "/" + field + + if getattr(obj, 'custom_script_source', None) is not None: + key = 'customscript/' + obj.custom_script_source + '/' + key + + if getattr(obj, 'do_not_save_to_config', False): + return + + saved_value = ui_settings.get(key, None) + if saved_value is None: + ui_settings[key] = getattr(obj, field) + elif condition and not condition(saved_value): + print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + else: + setattr(obj, field, saved_value) + if init_field is not None: + init_field(saved_value) + + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: + apply_field(x, 'visible') + + if type(x) == gr.Slider: + apply_field(x, 'value') + apply_field(x, 'minimum') + apply_field(x, 'maximum') + apply_field(x, 'step') + + if type(x) == gr.Radio: + apply_field(x, 'value', lambda val: val in x.choices) + + if type(x) == gr.Checkbox: + apply_field(x, 'value') + + if type(x) == gr.Textbox: + apply_field(x, 'value') + + if type(x) == gr.Number: + apply_field(x, 'value') + + # Since there are many dropdowns that shouldn't be saved, + # we only mark dropdowns that should be saved. + if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False): + apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) + apply_field(x, 'visible') + + visit(txt2img_interface, loadsave, "txt2img") + visit(img2img_interface, loadsave, "img2img") + visit(extras_interface, loadsave, "extras") + visit(modelmerger_interface, loadsave, "modelmerger") + + if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): + with open(ui_config_file, "w", encoding="utf8") as file: + json.dump(ui_settings, file, indent=4) + + return demo + + +def reload_javascript(): + with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: + javascript = f'' + + scripts_list = modules.scripts.list_scripts("javascript", ".js") + + for basedir, filename, path in scripts_list: + with open(path, "r", encoding="utf8") as jsfile: + javascript += f"\n" + + if cmd_opts.theme is not None: + javascript += f"\n\n" + + javascript += f"\n" + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace( + b'', f'{javascript}'.encode("utf8")) + res.init_headers() + return res + + gradio.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse From 927d24ef82f0eedebabb908c43829388f413d73c Mon Sep 17 00:00:00 2001 From: Liam Date: Sun, 20 Nov 2022 13:52:18 -0500 Subject: [PATCH 26/72] made selected_gallery_index query selectors more restrictive --- javascript/ui.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index 95cfd106f..2ca66d79e 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -8,8 +8,8 @@ function set_theme(theme){ } function selected_gallery_index(){ - var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item') - var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2') + var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item') + var button = gradioApp().querySelector('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item.\\!ring-2') var result = -1 buttons.forEach(function(v, i){ if(v==button) { result = i } }) From 5b57f61ba47f8b11d19a5b46e7fb5a52458abae5 Mon Sep 17 00:00:00 2001 From: flamelaw Date: Mon, 21 Nov 2022 10:15:46 +0900 Subject: [PATCH 27/72] fix pin_memory with different latent sampling method --- modules/hypernetworks/hypernetwork.py | 5 +++- modules/textual_inversion/dataset.py | 23 +++++++++++++++---- .../textual_inversion/textual_inversion.py | 7 +----- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3d3301b08..0128419bb 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -416,7 +416,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=pin_memory) + + latent_sampling_method = ds.latent_sampling_method + + dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 110c0e09b..f470324a7 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -138,9 +138,12 @@ class PersonalizedBase(Dataset): return entry class PersonalizedDataLoader(DataLoader): - def __init__(self, *args, **kwargs): - super(PersonalizedDataLoader, self).__init__(shuffle=True, drop_last=True, *args, **kwargs) - self.collate_fn = collate_wrapper + def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): + super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory) + if latent_sampling_method == "random": + self.collate_fn = collate_wrapper_random + else: + self.collate_fn = collate_wrapper class BatchLoader: @@ -148,10 +151,22 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) + #self.emb_index = [entry.emb_index for entry in data] + #print(self.latent_sample.device) def pin_memory(self): self.latent_sample = self.latent_sample.pin_memory() return self def collate_wrapper(batch): - return BatchLoader(batch) \ No newline at end of file + return BatchLoader(batch) + +class BatchLoaderRandom(BatchLoader): + def __init__(self, data): + super().__init__(data) + + def pin_memory(self): + return self + +def collate_wrapper_random(batch): + return BatchLoaderRandom(batch) \ No newline at end of file diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 1d5e3a322..3036e48a3 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -277,7 +277,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ latent_sampling_method = ds.latent_sampling_method - dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=False) + dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) if unload: shared.sd_model.first_stage_model.to(devices.cpu) @@ -333,11 +333,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ # go back until we reach gradient accumulation steps if (j + 1) % gradient_step != 0: continue - #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") - #scaler.unscale_(optimizer) - #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") - #torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=1.0) - #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}") scaler.step(optimizer) scaler.update() embedding.step += 1 From 9ae30b34508770c941f4d09187c0d8e82bb009c8 Mon Sep 17 00:00:00 2001 From: dtlnor Date: Mon, 21 Nov 2022 12:53:55 +0900 Subject: [PATCH 28/72] remove cmd args requirement for deepdanbooru --- modules/ui.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index e6da1b2ae..aba13926e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -478,9 +478,7 @@ def create_toprow(is_img2img): if is_img2img: with gr.Column(scale=1, elem_id="interrogate_col"): button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - - if cmd_opts.deepdanbooru: - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") with gr.Column(scale=1): with gr.Row(): @@ -1004,11 +1002,10 @@ def create_ui(wrap_gradio_gpu_call): outputs=[img2img_prompt], ) - if cmd_opts.deepdanbooru: - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], ) @@ -1240,7 +1237,7 @@ def create_ui(wrap_gradio_gpu_call): process_split = gr.Checkbox(label='Split oversized images') process_focal_crop = gr.Checkbox(label='Auto focal point crop') process_caption = gr.Checkbox(label='Use BLIP for caption') - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True) with gr.Row(visible=False) as process_split_extra_row: process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) From e247b7400a592c0a19c197cd080aeec38ee02b68 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 17 Nov 2022 03:52:17 -0500 Subject: [PATCH 29/72] Add fixes for PyTorch 1.12.1 Fix typo "MasOS" -> "macOS" If MPS is available and PyTorch is an earlier version than 1.13: * Monkey patch torch.Tensor.to to ensure all tensors sent to MPS are contiguous * Monkey patch torch.nn.functional.layer_norm to ensure input tensor is contiguous (required for this program to work with MPS on unmodified PyTorch 1.12.1) --- modules/devices.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/modules/devices.py b/modules/devices.py index a87d0d4c9..6e8277e58 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -2,9 +2,10 @@ import sys, os, shlex import contextlib import torch from modules import errors +from packaging import version -# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. +# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. # check `getattr` and try it for compatibility def has_mps() -> bool: if not getattr(torch, 'has_mps', False): @@ -94,3 +95,28 @@ def autocast(disable=False): return contextlib.nullcontext() return torch.autocast("cuda") + + +# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 +orig_tensor_to = torch.Tensor.to +def tensor_to_fix(self, *args, **kwargs): + if self.device.type != 'mps' and \ + ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \ + (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')): + self = self.contiguous() + return orig_tensor_to(self, *args, **kwargs) + + +# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 +orig_layer_norm = torch.nn.functional.layer_norm +def layer_norm_fix(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps': + args = list(args) + args[0] = args[0].contiguous() + return orig_layer_norm(*args, **kwargs) + + +# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working +if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): + torch.Tensor.to = tensor_to_fix + torch.nn.functional.layer_norm = layer_norm_fix From 563ea3f6ff66e0eba44033163d24e42297465a47 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 21 Nov 2022 02:56:00 -0500 Subject: [PATCH 30/72] Change .cuda() to .to(devices.device) --- modules/deepbooru.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/deepbooru.py b/modules/deepbooru.py index b9066d817..31ec7e171 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -58,7 +58,7 @@ class DeepDanbooru: a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 with torch.no_grad(), devices.autocast(): - x = torch.from_numpy(a).cuda() + x = torch.from_numpy(a).to(devices.device) y = self.model(x)[0].detach().cpu().numpy() probability_dict = {} From 0efffbb407a9d07eae6850374099775385ce176c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 21 Nov 2022 14:04:25 +0100 Subject: [PATCH 31/72] Supporting `*.safetensors` format. If a model file exists with extension `.safetensors` then we can load it more safely than with PyTorch weights. --- modules/sd_models.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf030..0164cc1b2 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -45,7 +45,7 @@ def checkpoint_tiles(): def list_models(): checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"]) + model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) def modeltitle(path, shorthash): abspath = os.path.abspath(path) @@ -180,7 +180,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if checkpoint_file.endswith(".safetensors"): + try: + from safetensors.torch import load_file + except ImportError as e: + raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}") + pl_sd = load_file(checkpoint_file, device=shared.weight_load_location) + else: + pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") From 162fef394f8d80b54df7ede9e3b7ba65da23d3c5 Mon Sep 17 00:00:00 2001 From: Tim Patton <38817597+pattontim@users.noreply.github.com> Date: Mon, 21 Nov 2022 10:50:57 -0500 Subject: [PATCH 32/72] Patch line ui endings --- modules/ui.py | 3628 ++++++++++++++++++++++++------------------------- 1 file changed, 1814 insertions(+), 1814 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index a2b06aae7..54d3293a8 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,1814 +1,1814 @@ -import html -import json -import math -import mimetypes -import os -import platform -import random -import subprocess as sp -import sys -import tempfile -import time -import traceback -from functools import partial, reduce - -import gradio as gr -import gradio.routes -import gradio.utils -import numpy as np -from PIL import Image, PngImagePlugin - - -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions -from modules.paths import script_path - -from modules.shared import opts, cmd_opts, restricted_opts - -if cmd_opts.deepdanbooru: - from modules.deepbooru import get_deepbooru_tags - -import modules.codeformer_model -import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.ldsr_model -import modules.scripts -import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui -from modules import prompt_parser -from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img -import modules.textual_inversion.ui -import modules.hypernetworks.ui -from modules.generation_parameters_copypaste import image_from_url_text - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -if not cmd_opts.share and not cmd_opts.listen: - # fix gradio phoning home - gradio.utils.version_check = lambda: None - gradio.utils.get_local_ip_address = lambda: '127.0.0.1' - -if cmd_opts.ngrok != None: - import modules.ngrok as ngrok - print('ngrok authtoken detected, trying to connect...') - ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region) - - -def gr_show(visible=True): - return {"visible": visible, "__type__": "update"} - - -sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" -sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None - -css_hide_progressbar = """ -.wrap .m-12 svg { display:none!important; } -.wrap .m-12::before { content:"Loading..." } -.wrap .z-20 svg { display:none!important; } -.wrap .z-20::before { content:"Loading..." } -.progress-bar { display:none!important; } -.meta-text { display:none!important; } -.meta-text-center { display:none!important; } -""" - -# Using constants for these since the variation selector isn't visible. -# Important that they exactly match script.js for tooltip to work. -random_symbol = '\U0001f3b2\ufe0f' # 🎲️ -reuse_symbol = '\u267b\ufe0f' # ♻️ -art_symbol = '\U0001f3a8' # 🎨 -paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -apply_style_symbol = '\U0001f4cb' # 📋 - - -def plaintext_to_html(text): - text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" - return text - -def send_gradio_gallery_to_image(x): - if len(x) == 0: - return None - return image_from_url_text(x[0]) - -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") - -def save_pil_to_file(pil_image, dir=None): - use_metadata = False - metadata = PngImagePlugin.PngInfo() - for key, value in pil_image.info.items(): - if isinstance(key, str) and isinstance(value, str): - metadata.add_text(key, value) - use_metadata = True - - file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) - pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) - return file_obj - - -# override save to file function so that it also writes PNG info -gr.processing_utils.save_pil_to_file = save_pil_to_file - - -def wrap_gradio_call(func, extra_outputs=None, add_stats=False): - def f(*args, extra_outputs_array=extra_outputs, **kwargs): - run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats - if run_memmon: - shared.mem_mon.monitor() - t = time.perf_counter() - - try: - res = list(func(*args, **kwargs)) - except Exception as e: - # When printing out our debug argument list, do not print out more than a MB of text - max_debug_str_len = 131072 # (1024*1024)/8 - - print("Error completing request", file=sys.stderr) - argStr = f"Arguments: {str(args)} {str(kwargs)}" - print(argStr[:max_debug_str_len], file=sys.stderr) - if len(argStr) > max_debug_str_len: - print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) - - print(traceback.format_exc(), file=sys.stderr) - - shared.state.job = "" - shared.state.job_count = 0 - - if extra_outputs_array is None: - extra_outputs_array = [None, ''] - - res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] - - shared.state.skipped = False - shared.state.interrupted = False - shared.state.job_count = 0 - - if not add_stats: - return tuple(res) - - elapsed = time.perf_counter() - t - elapsed_m = int(elapsed // 60) - elapsed_s = elapsed % 60 - elapsed_text = f"{elapsed_s:.2f}s" - if elapsed_m > 0: - elapsed_text = f"{elapsed_m}m "+elapsed_text - - if run_memmon: - mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} - active_peak = mem_stats['active_peak'] - reserved_peak = mem_stats['reserved_peak'] - sys_peak = mem_stats['system_peak'] - sys_total = mem_stats['total'] - sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) - - vram_html = f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" - else: - vram_html = '' - - # last item is always HTML - res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" - - return tuple(res) - - return f - - -def calc_time_left(progress, threshold, label, force_display): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and progress > 0.02) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display ) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}
""" - - image = gr_show(False) - preview_visibility = gr_show(False) - - if opts.show_progress_every_n_steps != 0: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr_show(True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr_show(False) - - return f"

{progressbar}

", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def roll_artist(prompt): - allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories]) - artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats]) - - return prompt + ", " + artist.name if prompt != '' else artist.name - - -def visit(x, func, path=""): - if hasattr(x, 'children'): - for c in x.children: - visit(c, func, path) - elif x.label is not None: - func(path + "/" + str(x.label), x) - - -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] - - -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] - - -def interrogate(image): - prompt = shared.interrogator.interrogate(image) - - return gr_show(True) if prompt is None else prompt - - -def interrogate_deepbooru(image): - prompt = get_deepbooru_tags(image) - return gr_show(True) if prompt is None else prompt - - -def create_seed_inputs(): - with gr.Row(): - with gr.Box(): - with gr.Row(elem_id='seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1) - seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id='random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed') - - with gr.Box(elem_id='subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False) - - # Components to show/hide based on the 'Extra' checkbox - seed_extras = [] - - with gr.Row(visible=False) as seed_extra_row_1: - seed_extras.append(seed_extra_row_1) - with gr.Box(): - with gr.Row(elem_id='subseed_row'): - subseed = gr.Number(label='Variation seed', value=-1) - subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id='random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01) - - with gr.Row(visible=False) as seed_extra_row_2: - seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0) - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0) - - random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) - random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) - - def change_visibility(show): - return {comp: gr_show(show) for comp in seed_extras} - - seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) - - return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - - -def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): - """ Connects a 'reuse (sub)seed' button's click event so that it copies last used - (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength - was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" - def copy_seed(gen_info_string: str, index): - res = -1 - - try: - gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError as e: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) - - return [res, gr_show(False)] - - reuse_seed.click( - fn=copy_seed, - _js="(x, y) => [x, selected_gallery_index()]", - show_progress=False, - inputs=[generation_info, dummy_component], - outputs=[seed, dummy_component] - ) - - -def update_token_counter(text, steps): - try: - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) - prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) - - except Exception: - # a parsing error can happen here during typing, and we don't want to bother the user with - # messages related to it in console - prompt_schedules = [[[steps, text]]] - - flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) - prompts = [prompt_text for step, prompt_text in flat_prompts] - tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) - style_class = ' class="red"' if (token_count > max_length) else "" - return f"{token_count}/{max_length}" - - -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - - with gr.Row(elem_id="toprow"): - with gr.Column(scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, - placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, - placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Column(scale=1, elem_id="roll_col"): - roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_id="interrogate_col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - - if cmd_opts.deepdanbooru: - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1): - with gr.Row(): - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - prompt_style.save_to_config = True - - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - prompt_style2.save_to_config = True - - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - -def apply_setting(key, value): - if value is None: - return gr.update() - - if shared.cmd_opts.freeze_settings: - return gr.update() - - # dont allow model to be swapped when model hash exists in prompt - if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: - return gr.update() - - if key == "sd_model_checkpoint": - ckpt_info = sd_models.get_closet_checkpoint_match(value) - - if ckpt_info is not None: - value = ckpt_info.title - else: - return gr.update() - - comp_args = opts.data_labels[key].component_args - if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: - return - - valtype = type(opts.data_labels[key].default) - oldval = opts.data[key] - opts.data[key] = valtype(value) if valtype != type(None) else value - if oldval != value and opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() - - opts.save(shared.config_filename) - return value - - -def update_generation_info(args): - generation_info, html_info, img_index = args - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info - - -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - -def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel'): - with gr.Group(): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(): - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' - open_folder_button = gr.Button(folder_symbol, elem_id=button_id) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) - - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) - - with gr.Group(): - html_info = gr.HTML() - generation_info = gr.Textbox(visible=False) - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - do_make_zip, - html_info, - ], - outputs=[ - download_files, - html_info, - html_info, - html_info, - ] - ) - else: - html_info_x = gr.HTML() - html_info = gr.HTML() - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info - - -def create_ui(wrap_gradio_gpu_call): - import modules.img2img - import modules.txt2img - - reload_javascript() - - parameters_copypaste.reset() - - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) - - with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) - dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") - - with gr.Group(): - width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - - with gr.Row(): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) - tiling = gr.Checkbox(label='Tiling', value=False) - enable_hr = gr.Checkbox(label='Highres. fix', value=False) - - with gr.Row(visible=False) as hr_options: - firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0) - firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0) - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) - - with gr.Row(equal_height=True): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) - - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) - - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - - with gr.Group(): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - - txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - txt2img_args = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), - _js="submit", - inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - denoising_strength, - firstphase_width, - firstphase_height, - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info - ], - show_progress=False, - ) - - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) - - txt_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - txt_prompt_img - ], - outputs=[ - txt2img_prompt, - txt_prompt_img - ] - ) - - enable_hr.change( - fn=lambda x: gr_show(x), - inputs=[enable_hr], - outputs=[hr_options], - ) - - roll.click( - fn=roll_artist, - _js="update_txt2img_tokens", - inputs=[ - txt2img_prompt, - ], - outputs=[ - txt2img_prompt, - ] - ) - - txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (firstphase_width, "First pass size-1"), - (firstphase_height, "First pass size-2"), - *modules.scripts.scripts_txt2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) - - txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, - steps, - sampler_index, - cfg_scale, - seed, - width, - height, - ] - - token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) - - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) - - with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True) - - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img'): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480) - - with gr.TabItem('Inpaint', id='inpaint'): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) - - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") - - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) - - with gr.Row(): - mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index") - - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index") - - with gr.Row(): - inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) - inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) - - with gr.TabItem('Batch img2img', id='batch'): - hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) - - with gr.Row(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize") - - steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) - sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") - - with gr.Group(): - width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height") - - with gr.Row(): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) - tiling = gr.Checkbox(label='Tiling', value=False) - - with gr.Row(): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) - - with gr.Group(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75) - - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - - with gr.Group(): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() - - img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ] - ) - - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - - img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img), - _js="submit_img2img", - inputs=[ - dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, - init_img, - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - mask_mode, - steps, - sampler_index, - mask_blur, - inpainting_fill, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - denoising_strength, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - ] + custom_inputs, - outputs=[ - img2img_gallery, - generation_info, - html_info - ], - show_progress=False, - ) - - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) - - img2img_interrogate.click( - fn=interrogate, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - if cmd_opts.deepdanbooru: - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - - roll.click( - fn=roll_artist, - _js="update_img2img_tokens", - inputs=[ - img2img_prompt, - ], - outputs=[ - img2img_prompt, - ] - ) - - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], - ) - - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - - img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - *modules.scripts.scripts_img2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) - - modules.scripts.scripts_current = None - - with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image'): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil") - - with gr.TabItem('Batch Process'): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") - - with gr.TabItem('Batch from Directory'): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.") - show_extras_results = gr.Checkbox(label='Show result images', value=True) - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by'): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4) - with gr.TabItem('Scale to'): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan) - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer) - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer) - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) - - result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, - html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=modules.extras.clear_cache, - inputs=[], outputs=[] - ) - - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") - - with gr.Column(variant='panel'): - html = gr.HTML() - generation_info = gr.Textbox(visible=False) - html2 = gr.HTML() - with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) - - image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), - inputs=[image], - outputs=[html, generation_info, html2], - ) - - with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - - with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") - custom_name = gr.Textbox(label="Custom Name (Optional)") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") - save_as_half = gr.Checkbox(value=False, label="Save as float16") - save_as_safetensors = gr.Checkbox(value=False, label="Save as safetensors format") - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): - gr.HTML(value="

See wiki for detailed explanation.

") - - with gr.Row().style(equal_height=False): - with gr.Tabs(elem_id="train_tabs"): - - with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name") - initialization_text = gr.Textbox(label="Initialization text", value="*") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary') - - with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') - - with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory') - process_dst = gr.Textbox(label='Destination directory') - process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) - - with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies') - process_split = gr.Checkbox(label='Split oversized images') - process_focal_crop = gr.Checkbox(label='Auto focal point crop') - process_caption = gr.Checkbox(label='Use BLIP for caption') - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_debug = gr.Checkbox(label='Create debug image') - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt") - run_preprocess = gr.Button(value="Preprocess", variant='primary') - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") - with gr.Row(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - with gr.Row(): - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - with gr.Row(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") - - batch_size = gr.Number(label='Batch size', value=1, precision=0) - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) - training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - steps = gr.Number(label='Max steps', value=100000, precision=0) - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) - - with gr.Row(): - interrupt_training = gr.Button(value="Interrupt") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') - train_embedding = gr.Button(value="Train Embedding", variant='primary') - - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) - - script_callbacks.ui_train_tabs_callback(params) - - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) - ti_progress = gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) - - create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_embedding_name, - embedding_learn_rate, - batch_size, - dataset_directory, - log_directory, - training_width, - training_height, - steps, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_hypernetwork_name, - hypernetwork_learn_rate, - batch_size, - dataset_directory, - log_directory, - training_width, - training_height, - steps, - create_image_every, - save_embedding_every, - template_file, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - interrupt_training.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {str(t)} for key {key}') - - elem_id = "setting_"+key - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - with gr.Row(variant="compact"): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - - components = [] - component_dict = {} - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' - - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return gr.update(value=value), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - settings_submit = gr.Button(value="Apply settings", variant='primary') - result = gr.HTML() - - settings_cols = 3 - items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) - - quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') - - quicksettings_list = [] - - cols_displayed = 0 - items_displayed = 0 - previous_section = None - column = None - with gr.Row(elem_id="settings").style(equal_height=False): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): - if column is not None: - column.__exit__() - - column = gr.Column(variant='panel') - column.__enter__() - - items_displayed = 0 - cols_displayed += 1 - - previous_section = item.section - - elem_id, text = item.section - gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - items_displayed += 1 - - with gr.Row(): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - - with gr.Row(): - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - - restart_gradio.click( - fn=request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) - - if column is not None: - column.__exit__() - - interfaces = [ - (txt2img_interface, "txt2img", "txt2img"), - (img2img_interface, "img2img", "img2img"), - (extras_interface, "Extras", "extras"), - (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), - (train_interface, "Train", "ti"), - ] - - css = "" - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - with open(cssfile, "r", encoding="utf8") as file: - css += file.read() + "\n" - - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - css += file.read() + "\n" - - if not cmd_opts.no_progressbar_hiding: - css += css_hide_progressbar - - interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] - - extensions_interface = ui_extensions.create_ui() - interfaces += [(extensions_interface, "Extensions", "extensions")] - - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings"): - for i, k, item in quicksettings_list: - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component - - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() - - with gr.Tabs(elem_id="tabs") as tabs: - for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): - interface.render() - - if os.path.exists(os.path.join(script_path, "notification.mp3")): - audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for i, k, item in quicksettings_list: - component = component_dict[k] - - component.change( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [getattr(opts, key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - ) - - def modelmerger(*args): - try: - results = modules.extras.run_modelmerger(*args) - except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] - return results - - modelmerger_merge.click( - fn=modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - interp_method, - interp_amount, - save_as_half, - save_as_safetensors, - custom_name, - ], - outputs=[ - submit_result, - primary_model_name, - secondary_model_name, - tertiary_model_name, - component_dict['sd_model_checkpoint'], - ] - ) - - ui_config_file = cmd_opts.ui_config_file - ui_settings = {} - settings_count = len(ui_settings) - error_loading = False - - try: - if os.path.exists(ui_config_file): - with open(ui_config_file, "r", encoding="utf8") as file: - ui_settings = json.load(file) - except Exception: - error_loading = True - print("Error loading settings:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - def loadsave(path, x): - def apply_field(obj, field, condition=None, init_field=None): - key = path + "/" + field - - if getattr(obj, 'custom_script_source', None) is not None: - key = 'customscript/' + obj.custom_script_source + '/' + key - - if getattr(obj, 'do_not_save_to_config', False): - return - - saved_value = ui_settings.get(key, None) - if saved_value is None: - ui_settings[key] = getattr(obj, field) - elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') - else: - setattr(obj, field, saved_value) - if init_field is not None: - init_field(saved_value) - - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: - apply_field(x, 'visible') - - if type(x) == gr.Slider: - apply_field(x, 'value') - apply_field(x, 'minimum') - apply_field(x, 'maximum') - apply_field(x, 'step') - - if type(x) == gr.Radio: - apply_field(x, 'value', lambda val: val in x.choices) - - if type(x) == gr.Checkbox: - apply_field(x, 'value') - - if type(x) == gr.Textbox: - apply_field(x, 'value') - - if type(x) == gr.Number: - apply_field(x, 'value') - - # Since there are many dropdowns that shouldn't be saved, - # we only mark dropdowns that should be saved. - if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False): - apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - apply_field(x, 'visible') - - visit(txt2img_interface, loadsave, "txt2img") - visit(img2img_interface, loadsave, "img2img") - visit(extras_interface, loadsave, "extras") - visit(modelmerger_interface, loadsave, "modelmerger") - - if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): - with open(ui_config_file, "w", encoding="utf8") as file: - json.dump(ui_settings, file, indent=4) - - return demo - - -def reload_javascript(): - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' - - scripts_list = modules.scripts.list_scripts("javascript", ".js") - - for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" - - if cmd_opts.theme is not None: - javascript += f"\n\n" - - javascript += f"\n" - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse +import html +import json +import math +import mimetypes +import os +import platform +import random +import subprocess as sp +import sys +import tempfile +import time +import traceback +from functools import partial, reduce + +import gradio as gr +import gradio.routes +import gradio.utils +import numpy as np +from PIL import Image, PngImagePlugin + + +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions +from modules.paths import script_path + +from modules.shared import opts, cmd_opts, restricted_opts + +if cmd_opts.deepdanbooru: + from modules.deepbooru import get_deepbooru_tags + +import modules.codeformer_model +import modules.generation_parameters_copypaste as parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.ldsr_model +import modules.scripts +import modules.shared as shared +import modules.styles +import modules.textual_inversion.ui +from modules import prompt_parser +from modules.images import save_image +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img +import modules.textual_inversion.ui +import modules.hypernetworks.ui +from modules.generation_parameters_copypaste import image_from_url_text + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +if not cmd_opts.share and not cmd_opts.listen: + # fix gradio phoning home + gradio.utils.version_check = lambda: None + gradio.utils.get_local_ip_address = lambda: '127.0.0.1' + +if cmd_opts.ngrok != None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region) + + +def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + +sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None + +css_hide_progressbar = """ +.wrap .m-12 svg { display:none!important; } +.wrap .m-12::before { content:"Loading..." } +.wrap .z-20 svg { display:none!important; } +.wrap .z-20::before { content:"Loading..." } +.progress-bar { display:none!important; } +.meta-text { display:none!important; } +.meta-text-center { display:none!important; } +""" + +# Using constants for these since the variation selector isn't visible. +# Important that they exactly match script.js for tooltip to work. +random_symbol = '\U0001f3b2\ufe0f' # 🎲️ +reuse_symbol = '\u267b\ufe0f' # ♻️ +art_symbol = '\U0001f3a8' # 🎨 +paste_symbol = '\u2199\ufe0f' # ↙ +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 + + +def plaintext_to_html(text): + text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" + return text + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + return image_from_url_text(x[0]) + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = opts.outdir_save + save_to_dirs = opts.use_save_to_dirs_for_ui + extension: str = opts.samples_format + start_index = 0 + + if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(opts.outdir_save, exist_ok=True) + + with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") + +def save_pil_to_file(pil_image, dir=None): + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in pil_image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + + file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) + pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) + return file_obj + + +# override save to file function so that it also writes PNG info +gr.processing_utils.save_pil_to_file = save_pil_to_file + + +def wrap_gradio_call(func, extra_outputs=None, add_stats=False): + def f(*args, extra_outputs_array=extra_outputs, **kwargs): + run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats + if run_memmon: + shared.mem_mon.monitor() + t = time.perf_counter() + + try: + res = list(func(*args, **kwargs)) + except Exception as e: + # When printing out our debug argument list, do not print out more than a MB of text + max_debug_str_len = 131072 # (1024*1024)/8 + + print("Error completing request", file=sys.stderr) + argStr = f"Arguments: {str(args)} {str(kwargs)}" + print(argStr[:max_debug_str_len], file=sys.stderr) + if len(argStr) > max_debug_str_len: + print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) + + print(traceback.format_exc(), file=sys.stderr) + + shared.state.job = "" + shared.state.job_count = 0 + + if extra_outputs_array is None: + extra_outputs_array = [None, ''] + + res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] + + shared.state.skipped = False + shared.state.interrupted = False + shared.state.job_count = 0 + + if not add_stats: + return tuple(res) + + elapsed = time.perf_counter() - t + elapsed_m = int(elapsed // 60) + elapsed_s = elapsed % 60 + elapsed_text = f"{elapsed_s:.2f}s" + if elapsed_m > 0: + elapsed_text = f"{elapsed_m}m "+elapsed_text + + if run_memmon: + mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} + active_peak = mem_stats['active_peak'] + reserved_peak = mem_stats['reserved_peak'] + sys_peak = mem_stats['system_peak'] + sys_total = mem_stats['total'] + sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) + + vram_html = f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" + else: + vram_html = '' + + # last item is always HTML + res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" + + return tuple(res) + + return f + + +def calc_time_left(progress, threshold, label, force_display): + if progress == 0: + return "" + else: + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + eta_relative = eta-time_since_start + if (eta_relative > threshold and progress > 0.02) or force_display: + if eta_relative > 3600: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + elif eta_relative > 60: + return label + time.strftime('%M:%S', time.gmtime(eta_relative)) + else: + return label + time.strftime('%Ss', time.gmtime(eta_relative)) + else: + return "" + + +def check_progress_call(id_part): + if shared.state.job_count == 0: + return "", gr_show(False), gr_show(False), gr_show(False) + + progress = 0 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display ) + if time_left != "": + shared.state.time_left_force_display = True + + progress = min(progress, 1) + + progressbar = "" + if opts.show_progressbar: + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}
""" + + image = gr_show(False) + preview_visibility = gr_show(False) + + if opts.show_progress_every_n_steps != 0: + shared.state.set_current_image() + image = shared.state.current_image + + if image is None: + image = gr.update(value=None) + else: + preview_visibility = gr_show(True) + + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"

{progressbar}

", preview_visibility, image, textinfo_result + + +def check_progress_call_initial(id_part): + shared.state.job_count = -1 + shared.state.current_latent = None + shared.state.current_image = None + shared.state.textinfo = None + shared.state.time_start = time.time() + shared.state.time_left_force_display = False + + return check_progress_call(id_part) + + +def roll_artist(prompt): + allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories]) + artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats]) + + return prompt + ", " + artist.name if prompt != '' else artist.name + + +def visit(x, func, path=""): + if hasattr(x, 'children'): + for c in x.children: + visit(c, func, path) + elif x.label is not None: + func(path + "/" + str(x.label), x) + + +def add_style(name: str, prompt: str, negative_prompt: str): + if name is None: + return [gr_show() for x in range(4)] + + style = modules.styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we + # reserialize all styles every time we save them + shared.prompt_styles.save_styles(shared.styles_filename) + + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] + + +def apply_styles(prompt, prompt_neg, style1_name, style2_name): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] + + +def interrogate(image): + prompt = shared.interrogator.interrogate(image) + + return gr_show(True) if prompt is None else prompt + + +def interrogate_deepbooru(image): + prompt = get_deepbooru_tags(image) + return gr_show(True) if prompt is None else prompt + + +def create_seed_inputs(): + with gr.Row(): + with gr.Box(): + with gr.Row(elem_id='seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1) + seed.style(container=False) + random_seed = gr.Button(random_symbol, elem_id='random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed') + + with gr.Box(elem_id='subseed_show_box'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False) + + # Components to show/hide based on the 'Extra' checkbox + seed_extras = [] + + with gr.Row(visible=False) as seed_extra_row_1: + seed_extras.append(seed_extra_row_1) + with gr.Box(): + with gr.Row(elem_id='subseed_row'): + subseed = gr.Number(label='Variation seed', value=-1) + subseed.style(container=False) + random_subseed = gr.Button(random_symbol, elem_id='random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed') + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01) + + with gr.Row(visible=False) as seed_extra_row_2: + seed_extras.append(seed_extra_row_2) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0) + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0) + + random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) + random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) + + def change_visibility(show): + return {comp: gr_show(show) for comp in seed_extras} + + seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) + + return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox + + +def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): + """ Connects a 'reuse (sub)seed' button's click event so that it copies last used + (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength + was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" + def copy_seed(gen_info_string: str, index): + res = -1 + + try: + gen_info = json.loads(gen_info_string) + index -= gen_info.get('index_of_first_image', 0) + + if is_subseed and gen_info.get('subseed_strength', 0) > 0: + all_subseeds = gen_info.get('all_subseeds', [-1]) + res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] + else: + all_seeds = gen_info.get('all_seeds', [-1]) + res = all_seeds[index if 0 <= index < len(all_seeds) else 0] + + except json.decoder.JSONDecodeError as e: + if gen_info_string != '': + print("Error parsing JSON generation info:", file=sys.stderr) + print(gen_info_string, file=sys.stderr) + + return [res, gr_show(False)] + + reuse_seed.click( + fn=copy_seed, + _js="(x, y) => [x, selected_gallery_index()]", + show_progress=False, + inputs=[generation_info, dummy_component], + outputs=[seed, dummy_component] + ) + + +def update_token_counter(text, steps): + try: + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + + except Exception: + # a parsing error can happen here during typing, and we don't want to bother the user with + # messages related to it in console + prompt_schedules = [[[steps, text]]] + + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step, prompt_text in flat_prompts] + tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) + style_class = ' class="red"' if (token_count > max_length) else "" + return f"{token_count}/{max_length}" + + +def create_toprow(is_img2img): + id_part = "img2img" if is_img2img else "txt2img" + + with gr.Row(elem_id="toprow"): + with gr.Column(scale=6): + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, + placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Column(scale=1, elem_id="roll_col"): + roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) + paste = gr.Button(value=paste_symbol, elem_id="paste") + save_style = gr.Button(value=save_style_symbol, elem_id="style_create") + prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") + + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + + button_interrogate = None + button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_id="interrogate_col"): + button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + + if cmd_opts.deepdanbooru: + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + + with gr.Column(scale=1): + with gr.Row(): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") + interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") + submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + with gr.Row(): + with gr.Column(scale=1, elem_id="style_pos_col"): + prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + prompt_style.save_to_config = True + + with gr.Column(scale=1, elem_id="style_neg_col"): + prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + prompt_style2.save_to_config = True + + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + + +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) + check_progress.click( + fn=lambda: check_progress_call(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) + check_progress_initial.click( + fn=lambda: check_progress_call_initial(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + +def apply_setting(key, value): + if value is None: + return gr.update() + + if shared.cmd_opts.freeze_settings: + return gr.update() + + # dont allow model to be swapped when model hash exists in prompt + if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: + return gr.update() + + if key == "sd_model_checkpoint": + ckpt_info = sd_models.get_closet_checkpoint_match(value) + + if ckpt_info is not None: + value = ckpt_info.title + else: + return gr.update() + + comp_args = opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data[key] + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + return value + + +def update_generation_info(args): + generation_info, html_info, img_index = args + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info + return plaintext_to_html(generation_info["infotexts"][img_index]) + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_output_panel(tabname, outdir): + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel'): + with gr.Group(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(): + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' + open_folder_button = gr.Button(folder_symbol, elem_id=button_id) + + open_folder_button.click( + fn=lambda: open_folder(opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) + + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) + + with gr.Group(): + html_info = gr.HTML() + generation_info = gr.Textbox(visible=False) + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="(x, y) => [x, y, selected_gallery_index()]", + inputs=[generation_info, html_info], + outputs=[html_info], + preprocess=False + ) + + save.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + do_make_zip, + html_info, + ], + outputs=[ + download_files, + html_info, + html_info, + html_info, + ] + ) + else: + html_info_x = gr.HTML() + html_info = gr.HTML() + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info + + +def create_ui(wrap_gradio_gpu_call): + import modules.img2img + import modules.txt2img + + reload_javascript() + + parameters_copypaste.reset() + + modules.scripts.scripts_current = modules.scripts.scripts_txt2img + modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + + with gr.Blocks(analytics_enabled=False) as txt2img_interface: + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Row(elem_id='txt2img_progress_row'): + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="txt2img_progressbar") + txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) + setup_progressbar(progressbar, txt2img_preview, 'txt2img') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") + + with gr.Group(): + width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + + with gr.Row(): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) + tiling = gr.Checkbox(label='Tiling', value=False) + enable_hr = gr.Checkbox(label='Highres. fix', value=False) + + with gr.Row(visible=False) as hr_options: + firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0) + firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0) + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) + + with gr.Row(equal_height=True): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) + + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) + + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() + + with gr.Group(): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + + txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) + parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + txt2img_args = dict( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), + _js="submit", + inputs=[ + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + denoising_strength, + firstphase_width, + firstphase_height, + ] + custom_inputs, + + outputs=[ + txt2img_gallery, + generation_info, + html_info + ], + show_progress=False, + ) + + txt2img_prompt.submit(**txt2img_args) + submit.click(**txt2img_args) + + txt_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ] + ) + + enable_hr.change( + fn=lambda x: gr_show(x), + inputs=[enable_hr], + outputs=[hr_options], + ) + + roll.click( + fn=roll_artist, + _js="update_txt2img_tokens", + inputs=[ + txt2img_prompt, + ], + outputs=[ + txt2img_prompt, + ] + ) + + txt2img_paste_fields = [ + (txt2img_prompt, "Prompt"), + (txt2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (enable_hr, lambda d: "Denoising strength" in d), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (firstphase_width, "First pass size-1"), + (firstphase_height, "First pass size-2"), + *modules.scripts.scripts_txt2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) + + txt2img_preview_params = [ + txt2img_prompt, + txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + + token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) + + modules.scripts.scripts_current = modules.scripts.scripts_img2img + modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + + with gr.Blocks(analytics_enabled=False) as img2img_interface: + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True) + + with gr.Row(elem_id='img2img_progress_row'): + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="img2img_progressbar") + img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) + setup_progressbar(progressbar, img2img_preview, 'img2img') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + + with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: + with gr.TabItem('img2img', id='img2img'): + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480) + + with gr.TabItem('Inpaint', id='inpaint'): + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) + + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) + + with gr.Row(): + mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") + inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index") + + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index") + + with gr.Row(): + inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) + inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) + + with gr.TabItem('Batch img2img', id='batch'): + hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML(f"

Process images in a directory on the same machine where the server is running.
Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}

") + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) + + with gr.Row(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize") + + steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) + sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") + + with gr.Group(): + width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height") + + with gr.Row(): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) + tiling = gr.Checkbox(label='Tiling', value=False) + + with gr.Row(): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) + + with gr.Group(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75) + + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() + + with gr.Group(): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() + + img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) + parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + img2img_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + img2img_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ] + ) + + mask_mode.change( + lambda mode, img: { + init_img_with_mask: gr_show(mode == 0), + init_img_inpaint: gr_show(mode == 1), + init_mask_inpaint: gr_show(mode == 1), + }, + inputs=[mask_mode, init_img_with_mask], + outputs=[ + init_img_with_mask, + init_img_inpaint, + init_mask_inpaint, + ], + ) + + img2img_args = dict( + fn=wrap_gradio_gpu_call(modules.img2img.img2img), + _js="submit_img2img", + inputs=[ + dummy_component, + img2img_prompt, + img2img_negative_prompt, + img2img_prompt_style, + img2img_prompt_style2, + init_img, + init_img_with_mask, + init_img_inpaint, + init_mask_inpaint, + mask_mode, + steps, + sampler_index, + mask_blur, + inpainting_fill, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + denoising_strength, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + resize_mode, + inpaint_full_res, + inpaint_full_res_padding, + inpainting_mask_invert, + img2img_batch_input_dir, + img2img_batch_output_dir, + ] + custom_inputs, + outputs=[ + img2img_gallery, + generation_info, + html_info + ], + show_progress=False, + ) + + img2img_prompt.submit(**img2img_args) + submit.click(**img2img_args) + + img2img_interrogate.click( + fn=interrogate, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + if cmd_opts.deepdanbooru: + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + + roll.click( + fn=roll_artist, + _js="update_img2img_tokens", + inputs=[ + img2img_prompt, + ], + outputs=[ + img2img_prompt, + ] + ) + + prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] + style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] + + for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): + button.click( + fn=add_style, + _js="ask_for_style_name", + # Have to pass empty dummy component here, because the JavaScript and Python function have to accept + # the same number of parameters, but we only know the style-name after the JavaScript prompt + inputs=[dummy_component, prompt, negative_prompt], + outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], + ) + + for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + button.click( + fn=apply_styles, + _js=js_func, + inputs=[prompt, negative_prompt, style1, style2], + outputs=[prompt, negative_prompt, style1, style2], + ) + + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + + img2img_paste_fields = [ + (img2img_prompt, "Prompt"), + (img2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + *modules.scripts.scripts_img2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + + modules.scripts.scripts_current = None + + with gr.Blocks(analytics_enabled=False) as extras_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image'): + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil") + + with gr.TabItem('Batch Process'): + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") + + with gr.TabItem('Batch from Directory'): + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.") + show_extras_results = gr.Checkbox(label='Show result images', value=True) + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by'): + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4) + with gr.TabItem('Scale to'): + with gr.Group(): + with gr.Row(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) + + with gr.Group(): + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + + with gr.Group(): + extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) + + with gr.Group(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan) + + with gr.Group(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer) + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer) + + with gr.Group(): + upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) + + result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) + + submit.click( + fn=wrap_gradio_gpu_call(modules.extras.run_extras), + _js="get_extras_tab_index", + inputs=[ + dummy_component, + dummy_component, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + gfpgan_visibility, + codeformer_visibility, + codeformer_weight, + upscaling_resize, + upscaling_resize_w, + upscaling_resize_h, + upscaling_crop, + extras_upscaler_1, + extras_upscaler_2, + extras_upscaler_2_visibility, + upscale_before_face_fix, + ], + outputs=[ + result_images, + html_info_x, + html_info, + ] + ) + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=modules.extras.clear_cache, + inputs=[], outputs=[] + ) + + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") + + with gr.Column(variant='panel'): + html = gr.HTML() + generation_info = gr.Textbox(visible=False) + html2 = gr.HTML() + with gr.Row(): + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) + parameters_copypaste.bind_buttons(buttons, image, generation_info) + + image.change( + fn=wrap_gradio_call(modules.extras.run_pnginfo), + inputs=[image], + outputs=[html, generation_info, html2], + ) + + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") + + with gr.Row(): + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + custom_name = gr.Textbox(label="Custom Name (Optional)") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") + save_as_half = gr.Checkbox(value=False, label="Save as float16") + save_as_safetensors = gr.Checkbox(value=False, label="Save as safetensors format") + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + + with gr.Column(variant='panel'): + submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + with gr.Blocks(analytics_enabled=False) as train_interface: + with gr.Row().style(equal_height=False): + gr.HTML(value="

See wiki for detailed explanation.

") + + with gr.Row().style(equal_height=False): + with gr.Tabs(elem_id="train_tabs"): + + with gr.Tab(label="Create embedding"): + new_embedding_name = gr.Textbox(label="Name") + initialization_text = gr.Textbox(label="Initialization text", value="*") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create embedding", variant='primary') + + with gr.Tab(label="Create hypernetwork"): + new_hypernetwork_name = gr.Textbox(label="Name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') + + with gr.Tab(label="Preprocess images"): + process_src = gr.Textbox(label='Source directory') + process_dst = gr.Textbox(label='Destination directory') + process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) + + with gr.Row(): + process_flip = gr.Checkbox(label='Create flipped copies') + process_split = gr.Checkbox(label='Split oversized images') + process_focal_crop = gr.Checkbox(label='Auto focal point crop') + process_caption = gr.Checkbox(label='Use BLIP for caption') + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) + + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + + with gr.Row(visible=False) as process_focal_crop_row: + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05) + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + process_focal_crop_debug = gr.Checkbox(label='Create debug image') + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt") + run_preprocess = gr.Button(value="Preprocess", variant='primary') + + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + + process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[process_focal_crop], + outputs=[process_focal_crop_row], + ) + + with gr.Tab(label="Train"): + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") + with gr.Row(): + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + with gr.Row(): + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") + with gr.Row(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") + + batch_size = gr.Number(label='Batch size', value=1, precision=0) + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") + template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + steps = gr.Number(label='Max steps', value=100000, precision=0) + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) + + with gr.Row(): + interrupt_training = gr.Button(value="Interrupt") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') + train_embedding = gr.Button(value="Train Embedding", variant='primary') + + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + create_hypernetwork.click( + fn=modules.hypernetworks.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + overwrite_old_hypernetwork, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_width, + process_height, + preprocess_txt_action, + process_flip, + process_split, + process_caption, + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + embedding_learn_rate, + batch_size, + dataset_directory, + log_directory, + training_width, + training_height, + steps, + create_image_every, + save_embedding_every, + template_file, + save_image_with_stored_embedding, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + hypernetwork_learn_rate, + batch_size, + dataset_directory, + log_directory, + training_width, + training_height, + steps, + create_image_every, + save_embedding_every, + template_file, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {str(t)} for key {key}') + + elem_id = "setting_"+key + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + with gr.Row(variant="compact"): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + components = [] + component_dict = {} + + script_callbacks.ui_settings_callback() + opts.reorder() + + def run_settings(*args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' + + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return gr.update(value=value), opts.dumpjson() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + settings_submit = gr.Button(value="Apply settings", variant='primary') + result = gr.HTML() + + settings_cols = 3 + items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) + + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') + + quicksettings_list = [] + + cols_displayed = 0 + items_displayed = 0 + previous_section = None + column = None + with gr.Row(elem_id="settings").style(equal_height=False): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): + if column is not None: + column.__exit__() + + column = gr.Column(variant='panel') + column.__enter__() + + items_displayed = 0 + cols_displayed += 1 + + previous_section = item.section + + elem_id, text = item.section + gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) + + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + items_displayed += 1 + + with gr.Row(): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + + with gr.Row(): + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + def request_restart(): + shared.state.interrupt() + shared.state.need_restart = True + + restart_gradio.click( + fn=request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + if column is not None: + column.__exit__() + + interfaces = [ + (txt2img_interface, "txt2img", "txt2img"), + (img2img_interface, "img2img", "img2img"), + (extras_interface, "Extras", "extras"), + (pnginfo_interface, "PNG Info", "pnginfo"), + (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (train_interface, "Train", "ti"), + ] + + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + with open(cssfile, "r", encoding="utf8") as file: + css += file.read() + "\n" + + if os.path.exists(os.path.join(script_path, "user.css")): + with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + css += file.read() + "\n" + + if not cmd_opts.no_progressbar_hiding: + css += css_hide_progressbar + + interfaces += script_callbacks.ui_tabs_callback() + interfaces += [(settings_interface, "Settings", "settings")] + + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Row(elem_id="quicksettings"): + for i, k, item in quicksettings_list: + component = create_setting_component(k, is_quicksettings=True) + component_dict[k] = component + + parameters_copypaste.integrate_settings_paste_fields(component_dict) + parameters_copypaste.run_bind() + + with gr.Tabs(elem_id="tabs") as tabs: + for interface, label, ifid in interfaces: + with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): + interface.render() + + if os.path.exists(os.path.join(script_path, "notification.mp3")): + audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) + + text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + settings_submit.click( + fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), + inputs=components, + outputs=[text_settings, result], + ) + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + def get_settings_values(): + return [getattr(opts, key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + ) + + def modelmerger(*args): + try: + results = modules.extras.run_modelmerger(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + modules.sd_models.list_models() # to remove the potentially missing models from the list + return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] + return results + + modelmerger_merge.click( + fn=modelmerger, + inputs=[ + primary_model_name, + secondary_model_name, + tertiary_model_name, + interp_method, + interp_amount, + save_as_half, + save_as_safetensors, + custom_name, + ], + outputs=[ + submit_result, + primary_model_name, + secondary_model_name, + tertiary_model_name, + component_dict['sd_model_checkpoint'], + ] + ) + + ui_config_file = cmd_opts.ui_config_file + ui_settings = {} + settings_count = len(ui_settings) + error_loading = False + + try: + if os.path.exists(ui_config_file): + with open(ui_config_file, "r", encoding="utf8") as file: + ui_settings = json.load(file) + except Exception: + error_loading = True + print("Error loading settings:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def loadsave(path, x): + def apply_field(obj, field, condition=None, init_field=None): + key = path + "/" + field + + if getattr(obj, 'custom_script_source', None) is not None: + key = 'customscript/' + obj.custom_script_source + '/' + key + + if getattr(obj, 'do_not_save_to_config', False): + return + + saved_value = ui_settings.get(key, None) + if saved_value is None: + ui_settings[key] = getattr(obj, field) + elif condition and not condition(saved_value): + print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + else: + setattr(obj, field, saved_value) + if init_field is not None: + init_field(saved_value) + + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: + apply_field(x, 'visible') + + if type(x) == gr.Slider: + apply_field(x, 'value') + apply_field(x, 'minimum') + apply_field(x, 'maximum') + apply_field(x, 'step') + + if type(x) == gr.Radio: + apply_field(x, 'value', lambda val: val in x.choices) + + if type(x) == gr.Checkbox: + apply_field(x, 'value') + + if type(x) == gr.Textbox: + apply_field(x, 'value') + + if type(x) == gr.Number: + apply_field(x, 'value') + + # Since there are many dropdowns that shouldn't be saved, + # we only mark dropdowns that should be saved. + if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False): + apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) + apply_field(x, 'visible') + + visit(txt2img_interface, loadsave, "txt2img") + visit(img2img_interface, loadsave, "img2img") + visit(extras_interface, loadsave, "extras") + visit(modelmerger_interface, loadsave, "modelmerger") + + if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): + with open(ui_config_file, "w", encoding="utf8") as file: + json.dump(ui_settings, file, indent=4) + + return demo + + +def reload_javascript(): + with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: + javascript = f'' + + scripts_list = modules.scripts.list_scripts("javascript", ".js") + + for basedir, filename, path in scripts_list: + with open(path, "r", encoding="utf8") as jsfile: + javascript += f"\n" + + if cmd_opts.theme is not None: + javascript += f"\n\n" + + javascript += f"\n" + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace( + b'', f'{javascript}'.encode("utf8")) + res.init_headers() + return res + + gradio.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse From e134b74ce95773789f69d158d23e93b7fe9295dc Mon Sep 17 00:00:00 2001 From: Tim Patton <38817597+pattontim@users.noreply.github.com> Date: Mon, 21 Nov 2022 10:58:57 -0500 Subject: [PATCH 33/72] Ignore safetensor files --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index ee53044c5..21fa26a75 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__ *.ckpt +*.safetensors *.pth /ESRGAN/* /SwinIR/* From 210cb4c128afdd65fa998229a97d0694154983ea Mon Sep 17 00:00:00 2001 From: Tim Patton <38817597+pattontim@users.noreply.github.com> Date: Mon, 21 Nov 2022 16:40:18 -0500 Subject: [PATCH 34/72] Use GPU for loading safetensors, disable export --- modules/sd_models.py | 5 +++-- modules/ui.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 2f8c2c48e..2bbb3bf5a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -147,8 +147,9 @@ def torch_load(model_filename, model_info, map_override=None): map_override=shared.weight_load_location if not map_override else map_override if(checkpoint_types[model_info.exttype] == 'safetensors'): # safely load weights - # TODO: safetensors supports zero copy fast load to gpu, see issue #684 - return load_file(model_filename, device=map_override) + # TODO: safetensors supports zero copy fast load to gpu, see issue #684. + # GPU only for now, see https://github.com/huggingface/safetensors/issues/95 + return load_file(model_filename, device='cuda') else: return torch.load(model_filename, map_location=map_override) diff --git a/modules/ui.py b/modules/ui.py index 54d3293a8..c376a59df 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1187,7 +1187,8 @@ def create_ui(wrap_gradio_gpu_call): interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Save as float16") - save_as_safetensors = gr.Checkbox(value=False, label="Save as safetensors format") + # invisible until feature can be verified + save_as_safetensors = gr.Checkbox(value=False, label="Save as safetensors format", visible=False) modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') with gr.Column(variant='panel'): From c27a973c82c374f47fb279c1b9b8de7288fd729d Mon Sep 17 00:00:00 2001 From: Rogerooo Date: Tue, 22 Nov 2022 14:02:59 +0000 Subject: [PATCH 35/72] fix null negative_prompt on get requests Small typo that causes a bug when returning negative prompts from the get request. --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 7a567be38..08e03c13f 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -305,7 +305,7 @@ class Api: styleList = [] for k in shared.prompt_styles.styles: style = shared.prompt_styles.styles[k] - styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]}) + styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]}) return styleList From 6ecf72b6f72220849d3dcdc27a28bf07bacc73f4 Mon Sep 17 00:00:00 2001 From: uservar <63248296+uservar@users.noreply.github.com> Date: Tue, 22 Nov 2022 14:24:10 +0000 Subject: [PATCH 36/72] Update k-diffusion to Release 0.0.11 --- launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launch.py b/launch.py index d2f1055cc..497ea32bb 100644 --- a/launch.py +++ b/launch.py @@ -145,7 +145,7 @@ def prepare_enviroment(): stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") - k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "60e5042ca0da89c14d1dd59d73883280f8fce991") + k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "b325595b8a776d483f6935dfa7b45f01c27039e4) codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") From 0a01f5089127f1ab86625036526082f544344a10 Mon Sep 17 00:00:00 2001 From: uservar <63248296+uservar@users.noreply.github.com> Date: Tue, 22 Nov 2022 14:24:50 +0000 Subject: [PATCH 37/72] Add DPM++ SDE sampler --- modules/sd_samplers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 4fe678544..80e91d622 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -26,6 +26,7 @@ samplers_k_diffusion = [ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}), ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), + ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), @@ -33,6 +34,7 @@ samplers_k_diffusion = [ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}), ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), + ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}), ] samplers_data_k_diffusion = [ From 3c3c46be5f04480fee2a7afea77242f101eea00a Mon Sep 17 00:00:00 2001 From: uservar <63248296+uservar@users.noreply.github.com> Date: Tue, 22 Nov 2022 14:25:39 +0000 Subject: [PATCH 38/72] Update requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 762db4f34..e4e5ec642 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ kornia lark inflection GitPython +torchsde From 47ce73fbbf9fb361b1e12031934cefd81f3b2364 Mon Sep 17 00:00:00 2001 From: uservar <63248296+uservar@users.noreply.github.com> Date: Tue, 22 Nov 2022 14:26:09 +0000 Subject: [PATCH 39/72] Update requirements_versions.txt --- requirements_versions.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements_versions.txt b/requirements_versions.txt index 662ca6849..8d557fe38 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -25,3 +25,4 @@ kornia==0.6.7 lark==1.1.2 inflection==0.5.1 GitPython==3.1.27 +torchsde==0.2.5 From 45fd785436068f3b1c09fb7bc575118b6059fc7b Mon Sep 17 00:00:00 2001 From: uservar <63248296+uservar@users.noreply.github.com> Date: Tue, 22 Nov 2022 14:52:16 +0000 Subject: [PATCH 40/72] Update launch.py --- launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launch.py b/launch.py index 497ea32bb..3608c1039 100644 --- a/launch.py +++ b/launch.py @@ -145,7 +145,7 @@ def prepare_enviroment(): stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") - k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "b325595b8a776d483f6935dfa7b45f01c27039e4) + k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "b325595b8a776d483f6935dfa7b45f01c27039e4") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") From ac90cf38c6b55d57d37923aa1fe86c7374e32d0b Mon Sep 17 00:00:00 2001 From: Tim Patton <38817597+pattontim@users.noreply.github.com> Date: Tue, 22 Nov 2022 10:13:07 -0500 Subject: [PATCH 41/72] safetensors optional for now --- modules/sd_models.py | 9 ++++++++- requirements.txt | 1 - 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 2bbb3bf5a..75f7ab094 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,6 @@ import sys import gc from collections import namedtuple import torch -from safetensors.torch import load_file, save_file import re from omegaconf import OmegaConf @@ -149,6 +148,10 @@ def torch_load(model_filename, model_info, map_override=None): # safely load weights # TODO: safetensors supports zero copy fast load to gpu, see issue #684. # GPU only for now, see https://github.com/huggingface/safetensors/issues/95 + try: + from safetensors.torch import load_file + except ImportError as e: + raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}") return load_file(model_filename, device='cuda') else: return torch.load(model_filename, map_location=map_override) @@ -157,6 +160,10 @@ def torch_save(model, output_filename): basename, exttype = os.path.splitext(output_filename) if(checkpoint_types[exttype] == 'safetensors'): # [===== >] Reticulating brines... + try: + from safetensors.torch import save_file + except ImportError as e: + raise ImportError(f"Export as safetensors selected, yet it is not installed, use `pip install safetensors`: {e}") save_file(model, output_filename, metadata={"format": "pt"}) else: torch.save(model, output_filename) diff --git a/requirements.txt b/requirements.txt index f7de9f707..762db4f34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,3 @@ kornia lark inflection GitPython -safetensors From 89d8ecff09b426ddc89eb5b432825f8f4c218051 Mon Sep 17 00:00:00 2001 From: flamelaw Date: Wed, 23 Nov 2022 02:49:01 +0900 Subject: [PATCH 42/72] small fixes --- modules/hypernetworks/hypernetwork.py | 6 +++--- modules/textual_inversion/textual_inversion.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 0128419bb..4541af186 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -435,8 +435,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, optimizer_name = hypernetwork.optimizer_name else: print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!") - optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) - optimizer_name = 'AdamW' + optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate) + optimizer_name = 'AdamW' if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer. try: @@ -582,7 +582,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, shared.state.textinfo = f"""

Loss: {loss_step:.7f}
-Step: {hypernetwork.step}
+Step: {steps_done}
Last prompt: {html.escape(batch.cond_text[0])}
Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 3036e48a3..fee08e33e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -436,7 +436,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ shared.state.textinfo = f"""

Loss: {loss_step:.7f}
-Step: {embedding.step}
+Step: {steps_done}
Last prompt: {html.escape(batch.cond_text[0])}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}
From 75b67eebf21f72f5b693926476d9c3b12471f0d6 Mon Sep 17 00:00:00 2001 From: Sena <34237511+sena-nana@users.noreply.github.com> Date: Wed, 23 Nov 2022 17:43:58 +0800 Subject: [PATCH 43/72] Fix bare base64 not accept --- modules/api/api.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 7a567be38..648bd6a86 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -3,6 +3,7 @@ import io import time import uvicorn from threading import Lock +from io import BytesIO from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, FastAPI, HTTPException from fastapi.security import HTTPBasic, HTTPBasicCredentials @@ -13,7 +14,7 @@ from modules import sd_samplers, deepbooru from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.extras import run_extras, run_pnginfo -from PIL import PngImagePlugin +from PIL import PngImagePlugin,Image from modules.sd_models import checkpoints_list from modules.realesrgan_model import get_realesrgan_models from typing import List @@ -133,7 +134,10 @@ class Api: mask = img2imgreq.mask if mask: - mask = decode_base64_to_image(mask) + if mask.startswith("data:image/"): + mask = decode_base64_to_image(mask) + else: + mask = Image.open(BytesIO(base64.b64decode(mask))) populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, @@ -147,7 +151,10 @@ class Api: imgs = [] for img in init_images: - img = decode_base64_to_image(img) + if img.startswith("data:image/"): + img = decode_base64_to_image(img) + else: + img = Image.open(BytesIO(base64.b64decode(img))) imgs = [img] * p.batch_size p.init_images = imgs From d2c97fc3fe5857d6fba9ad1695ed3ac6ec455ca9 Mon Sep 17 00:00:00 2001 From: flamelaw Date: Wed, 23 Nov 2022 20:00:00 +0900 Subject: [PATCH 44/72] fix dropout, implement train/eval mode --- modules/hypernetworks/hypernetwork.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4541af186..9388959f5 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -154,16 +154,28 @@ class Hypernetwork: HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), ) + self.eval_mode() def weights(self): res = [] + for k, layers in self.layers.items(): + for layer in layers: + res += layer.parameters() + return res + def train_mode(self): for k, layers in self.layers.items(): for layer in layers: layer.train() - res += layer.trainables() + for param in layer.parameters(): + param.requires_grad = True - return res + def eval_mode(self): + for k, layers in self.layers.items(): + for layer in layers: + layer.eval() + for param in layer.parameters(): + param.requires_grad = False def save(self, filename): state_dict = {} @@ -426,8 +438,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, shared.sd_model.first_stage_model.to(devices.cpu) weights = hypernetwork.weights() - for weight in weights: - weight.requires_grad = True + hypernetwork.train_mode() # Here we use optimizer from saved HN, or we can specify as UI option. if hypernetwork.optimizer_name in optimizer_dict: @@ -538,7 +549,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if images_dir is not None and steps_done % create_image_every == 0: forced_filename = f'{hypernetwork_name}-{steps_done}' last_saved_image = os.path.join(images_dir, forced_filename) - + hypernetwork.eval_mode() shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) @@ -571,7 +582,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - + hypernetwork.train_mode() if image is not None: shared.state.current_image = image last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) @@ -593,6 +604,7 @@ Last saved image: {html.escape(last_saved_image)}
finally: pbar.leave = False pbar.close() + hypernetwork.eval_mode() #report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') From 1bd57cc9791e2e742f72a3d74d589f2c289e8e92 Mon Sep 17 00:00:00 2001 From: flamelaw Date: Wed, 23 Nov 2022 20:21:52 +0900 Subject: [PATCH 45/72] last_layer_dropout default to False --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 9388959f5..8466887f6 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -38,7 +38,7 @@ class HypernetworkModule(torch.nn.Module): activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', - add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True): + add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False): super().__init__() assert layer_structure is not None, "layer_structure must not be None" From 6001684be3e7b023346326b9dfc771219b8fe47e Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Wed, 23 Nov 2022 06:35:44 -0800 Subject: [PATCH 46/72] add model_name pattern for saving --- javascript/hints.js | 4 ++-- modules/images.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/javascript/hints.js b/javascript/hints.js index 623bc25cd..ac417ff65 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -62,8 +62,8 @@ titles = { "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.", - "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [datetime], [datetime

{html.escape(type(e).__name__+': '+str(e))}
"] + + shared.state.skipped = False + shared.state.interrupted = False + shared.state.job_count = 0 + + if not add_stats: + return tuple(res) + + elapsed = time.perf_counter() - t + elapsed_m = int(elapsed // 60) + elapsed_s = elapsed % 60 + elapsed_text = f"{elapsed_s:.2f}s" + if elapsed_m > 0: + elapsed_text = f"{elapsed_m}m "+elapsed_text + + if run_memmon: + mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} + active_peak = mem_stats['active_peak'] + reserved_peak = mem_stats['reserved_peak'] + sys_peak = mem_stats['system_peak'] + sys_total = mem_stats['total'] + sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) + + vram_html = f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" + else: + vram_html = '' + + # last item is always HTML + res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" + + return tuple(res) + + return f + diff --git a/modules/ui.py b/modules/ui.py index 446bee409..008093614 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -17,7 +17,7 @@ import gradio.routes import gradio.utils import numpy as np from PIL import Image, PngImagePlugin - +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru from modules.paths import script_path @@ -158,67 +158,6 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") -def wrap_gradio_call(func, extra_outputs=None, add_stats=False): - def f(*args, extra_outputs_array=extra_outputs, **kwargs): - run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats - if run_memmon: - shared.mem_mon.monitor() - t = time.perf_counter() - - try: - res = list(func(*args, **kwargs)) - except Exception as e: - # When printing out our debug argument list, do not print out more than a MB of text - max_debug_str_len = 131072 # (1024*1024)/8 - - print("Error completing request", file=sys.stderr) - argStr = f"Arguments: {str(args)} {str(kwargs)}" - print(argStr[:max_debug_str_len], file=sys.stderr) - if len(argStr) > max_debug_str_len: - print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) - - print(traceback.format_exc(), file=sys.stderr) - - shared.state.job = "" - shared.state.job_count = 0 - - if extra_outputs_array is None: - extra_outputs_array = [None, ''] - - res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] - - shared.state.skipped = False - shared.state.interrupted = False - shared.state.job_count = 0 - - if not add_stats: - return tuple(res) - - elapsed = time.perf_counter() - t - elapsed_m = int(elapsed // 60) - elapsed_s = elapsed % 60 - elapsed_text = f"{elapsed_s:.2f}s" - if elapsed_m > 0: - elapsed_text = f"{elapsed_m}m "+elapsed_text - - if run_memmon: - mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} - active_peak = mem_stats['active_peak'] - reserved_peak = mem_stats['reserved_peak'] - sys_peak = mem_stats['system_peak'] - sys_total = mem_stats['total'] - sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) - - vram_html = f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" - else: - vram_html = '' - - # last item is always HTML - res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" - - return tuple(res) - - return f def calc_time_left(progress, threshold, label, force_display): @@ -666,7 +605,7 @@ Requested path was: {f} return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info -def create_ui(wrap_gradio_gpu_call): +def create_ui(): import modules.img2img import modules.txt2img @@ -826,7 +765,7 @@ def create_ui(wrap_gradio_gpu_call): height, ] - token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) diff --git a/webui.py b/webui.py index 7a56bde80..16e7ec1a6 100644 --- a/webui.py +++ b/webui.py @@ -8,6 +8,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware +from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir @@ -32,38 +33,12 @@ from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork -queue_lock = threading.Lock() if cmd_opts.server_name: server_name = cmd_opts.server_name else: server_name = "0.0.0.0" if cmd_opts.listen else None -def wrap_queued_call(func): - def f(*args, **kwargs): - with queue_lock: - res = func(*args, **kwargs) - - return res - - return f - - -def wrap_gradio_gpu_call(func, extra_outputs=None): - def f(*args, **kwargs): - - shared.state.begin() - - with queue_lock: - res = func(*args, **kwargs) - - shared.state.end() - - return res - - return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) - - def initialize(): extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) @@ -159,7 +134,7 @@ def webui(): if shared.opts.clean_temp_dir_at_start: ui_tempdir.cleanup_tmpdr() - shared.demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + shared.demo = modules.ui.create_ui() app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, @@ -189,6 +164,7 @@ def webui(): create_api(app) modules.script_callbacks.app_started_callback(shared.demo, app) + modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo)