diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml new file mode 100644 index 000000000..5270cba4c --- /dev/null +++ b/.github/workflows/on_pull_request.yaml @@ -0,0 +1,36 @@ +# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file +name: Run Linting/Formatting on Pull Requests + +on: + - push + - pull_request + # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs + # if you want to filter out branches, delete the `- pull_request` and uncomment these lines : + # pull_request: + # branches: + # - master + # branches-ignore: + # - development + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: 3.10.6 + - name: Install PyLint + run: | + python -m pip install --upgrade pip + pip install pylint + # This lets PyLint check to see if it can resolve imports + - name: Install dependencies + run : | + export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit" + python launch.py + - name: Analysing the code with pylint + run: | + pylint $(git ls-files '*.py') diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 000000000..53254e5dc --- /dev/null +++ b/.pylintrc @@ -0,0 +1,3 @@ +# See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html +[MESSAGES CONTROL] +disable=C,R,W,E,I diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index 5aac57f77..070cf255f 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) { window.document.addEventListener('dragover', e => { const target = e.composedPath()[0]; const imgWrap = target.closest('[data-testid="image"]'); - if ( !imgWrap ) { + if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) { return; } e.stopPropagation(); @@ -53,6 +53,9 @@ window.document.addEventListener('dragover', e => { window.document.addEventListener('drop', e => { const target = e.composedPath()[0]; + if (target.placeholder.indexOf("Prompt") == -1) { + return; + } const imgWrap = target.closest('[data-testid="image"]'); if ( !imgWrap ) { return; diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 3f1d2fbbd..67084e7a6 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -2,6 +2,8 @@ addEventListener('keydown', (event) => { let target = event.originalTarget || event.composedPath()[0]; if (!target.hasAttribute("placeholder")) return; if (!target.placeholder.toLowerCase().includes("prompt")) return; + if (! 
(event.metaKey || event.ctrlKey)) return; + let plus = "ArrowUp" let minus = "ArrowDown" diff --git a/javascript/hints.js b/javascript/hints.js index af010a59d..b98012f5e 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -16,6 +16,8 @@ titles = { "\u{1f3a8}": "Add a random artist to the prompt.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u{1f4c2}": "Open images output directory", + "\u{1f4be}": "Save style", + "\u{1f4cb}": "Apply selected styles to current prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", @@ -87,8 +89,8 @@ titles = { "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.", - "Weighted Sum": "Result = A * (1 - M) + B * M", - "Add difference": "Result = A + (B - C) * (1 - M)", + "Weighted sum": "Result = A * (1 - M) + B * M", + "Add difference": "Result = A + (B - C) * M", } diff --git a/javascript/imageParams.js b/javascript/imageParams.js new file mode 100644 index 000000000..67404a89b --- /dev/null +++ b/javascript/imageParams.js @@ -0,0 +1,19 @@ +window.onload = (function(){ + window.addEventListener('drop', e => { + const target = e.composedPath()[0]; + const idx = selected_gallery_index(); + if (target.placeholder.indexOf("Prompt") == -1) return; + + let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; + + e.stopPropagation(); + e.preventDefault(); + const imgParent = gradioApp().getElementById(prompt_target); + const files = e.dataTransfer.files; + const fileInput = imgParent.querySelector('input[type="file"]'); + if ( fileInput ) { + fileInput.files = files; + fileInput.dispatchEvent(new Event('change')); + } + }); +}); diff --git a/javascript/images_history.js b/javascript/images_history.js new file mode 100644 index 000000000..f7d052c3d --- /dev/null +++ b/javascript/images_history.js @@ -0,0 +1,206 @@ +var images_history_click_image = function(){ + if (!this.classList.contains("transform")){ + var gallery = images_history_get_parent_by_class(this, "images_history_cantainor"); + var buttons = gallery.querySelectorAll(".gallery-item"); + var i = 0; + var hidden_list = []; + buttons.forEach(function(e){ + if (e.style.display == "none"){ + hidden_list.push(i); + } + i += 1; + }) + if (hidden_list.length > 0){ + setTimeout(images_history_hide_buttons, 10, hidden_list, gallery); + } + } + images_history_set_image_info(this); +} + +var images_history_click_tab = function(){ + var tabs_box = gradioApp().getElementById("images_history_tab"); + if (!tabs_box.classList.contains(this.getAttribute("tabname"))) { + gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click(); + tabs_box.classList.add(this.getAttribute("tabname")) + } +} + +function images_history_disabled_del(){ + gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){ + btn.setAttribute('disabled','disabled'); + }); +} + +function images_history_get_parent_by_class(item, class_name){ + var parent = item.parentElement; + while(!parent.classList.contains(class_name)){ + parent = parent.parentElement; + } + return parent; +} + 
+function images_history_get_parent_by_tagname(item, tagname){ + var parent = item.parentElement; + tagname = tagname.toUpperCase() + while(parent.tagName != tagname){ + console.log(parent.tagName, tagname) + parent = parent.parentElement; + } + return parent; +} + +function images_history_hide_buttons(hidden_list, gallery){ + var buttons = gallery.querySelectorAll(".gallery-item"); + var num = 0; + buttons.forEach(function(e){ + if (e.style.display == "none"){ + num += 1; + } + }); + if (num == hidden_list.length){ + setTimeout(images_history_hide_buttons, 10, hidden_list, gallery); + } + for( i in hidden_list){ + buttons[hidden_list[i]].style.display = "none"; + } +} + +function images_history_set_image_info(button){ + var buttons = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item"); + var index = -1; + var i = 0; + buttons.forEach(function(e){ + if(e == button){ + index = i; + } + if(e.style.display != "none"){ + i += 1; + } + }); + var gallery = images_history_get_parent_by_class(button, "images_history_cantainor"); + var set_btn = gallery.querySelector(".images_history_set_index"); + var curr_idx = set_btn.getAttribute("img_index", index); + if (curr_idx != index) { + set_btn.setAttribute("img_index", index); + images_history_disabled_del(); + } + set_btn.click(); + +} + +function images_history_get_current_img(tabname, image_path, files){ + return [ + gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"), + image_path, + files + ]; +} + +function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){ + image_index = parseInt(image_index); + var tab = gradioApp().getElementById(tabname + '_images_history'); + var set_btn = tab.querySelector(".images_history_set_index"); + var buttons = []; + tab.querySelectorAll(".gallery-item").forEach(function(e){ + if (e.style.display != 'none'){ + buttons.push(e); + } + }); + var img_num = buttons.length / 2; + if (img_num <= del_num){ + setTimeout(function(tabname){ + gradioApp().getElementById(tabname + '_images_history_renew_page').click(); + }, 30, tabname); + } else { + var next_img + for (var i = 0; i < del_num; i++){ + if (image_index + i < image_index + img_num){ + buttons[image_index + i].style.display = 'none'; + buttons[image_index + img_num + 1].style.display = 'none'; + next_img = image_index + i + 1 + } + } + var bnt; + if (next_img >= img_num){ + btn = buttons[image_index - del_num]; + } else { + btn = buttons[next_img]; + } + setTimeout(function(btn){btn.click()}, 30, btn); + } + images_history_disabled_del(); + return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index]; +} + +function images_history_turnpage(img_path, page_index, image_index, tabname){ + var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item"); + buttons.forEach(function(elem) { + elem.style.display = 'block'; + }) + return [img_path, page_index, image_index, tabname]; +} + +function images_history_enable_del_buttons(){ + gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){ + btn.removeAttribute('disabled'); + }) +} + +function images_history_init(){ + var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page') + if (load_txt2img_button){ + for (var i in images_history_tab_list ){ + tab = images_history_tab_list[i]; + gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor"); + 
gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index"); + gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button"); + gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery"); + + } + var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div"); + tabs_box.setAttribute("id", "images_history_tab"); + var tab_btns = tabs_box.querySelectorAll("button"); + for (var i in images_history_tab_list){ + var tabname = images_history_tab_list[i] + tab_btns[i].setAttribute("tabname", tabname); + + // this refreshes history upon tab switch + // until the history is known to work well, which is not the case now, we do not do this at startup + //tab_btns[i].addEventListener('click', images_history_click_tab); + } + tabs_box.classList.add(images_history_tab_list[0]); + + // same as above, at page load + //load_txt2img_button.click(); + } else { + setTimeout(images_history_init, 500); + } +} + +var images_history_tab_list = ["txt2img", "img2img", "extras"]; +setTimeout(images_history_init, 500); +document.addEventListener("DOMContentLoaded", function() { + var mutationObserver = new MutationObserver(function(m){ + for (var i in images_history_tab_list ){ + let tabname = images_history_tab_list[i] + var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item'); + buttons.forEach(function(bnt){ + bnt.addEventListener('click', images_history_click_image, true); + }); + + // same as load_txt2img_button.click() above + /* + var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg"); + if (cls_btn){ + cls_btn.addEventListener('click', function(){ + gradioApp().getElementById(tabname + '_images_history_renew_page').click(); + }, false); + }*/ + + } + }); + mutationObserver.observe( gradioApp(), { childList:true, subtree:true }); + +}); + + diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 4395a2159..076f0a973 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -1,5 +1,7 @@ // code related to showing and updating progressbar shown as the image is being made global_progressbars = {} +galleries = {} +galleryObservers = {} function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ var progressbar = gradioApp().getElementById(id_progressbar) @@ -31,13 +33,24 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip preview.style.width = gallery.clientWidth + "px" preview.style.height = gallery.clientHeight + "px" + //only watch gallery if there is a generation process going on + check_gallery(id_gallery); + var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; if(!progressDiv){ if (skip) { skip.style.display = "none" } interrupt.style.display = "none" + + //disconnect observer once generation finished, so user can close selected image if they want + if (galleryObservers[id_gallery]) { + galleryObservers[id_gallery].disconnect(); + galleries[id_gallery] = null; + } } + + } window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500) @@ -46,6 +59,28 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip } } +function check_gallery(id_gallery){ + let gallery = gradioApp().getElementById(id_gallery) + // if 
gallery has no change, no need to setting up observer again. + if (gallery && galleries[id_gallery] !== gallery){ + galleries[id_gallery] = gallery; + if(galleryObservers[id_gallery]){ + galleryObservers[id_gallery].disconnect(); + } + let prevSelectedIndex = selected_gallery_index(); + galleryObservers[id_gallery] = new MutationObserver(function (){ + let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') + let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') + if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { + //automatically re-open previously selected index (if exists) + galleryButtons[prevSelectedIndex].click(); + showGalleryImage(); + } + }) + galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false }) + } +} + onUiUpdate(function(){ check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') diff --git a/javascript/ui.js b/javascript/ui.js index 0f8fe68ef..9e1bed4cb 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -141,7 +141,7 @@ function submit_img2img(){ function ask_for_style_name(_, prompt_text, negative_prompt_text) { name_ = prompt('Style name:') - return name_ === null ? [null, null, null]: [name_, prompt_text, negative_prompt_text] + return [name_, prompt_text, negative_prompt_text] } @@ -187,12 +187,10 @@ onUiUpdate(function(){ if (!txt2img_textarea) { txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea"); txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button")); - txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate")); } if (!img2img_textarea) { img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); - img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate")); } }) @@ -220,14 +218,6 @@ function update_token_counter(button_id) { token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); } -function submit_prompt(event, generate_button_id) { - if (event.altKey && event.keyCode === 13) { - event.preventDefault(); - gradioApp().getElementById(generate_button_id).click(); - return; - } -} - function restart_reload(){ document.body.innerHTML='
<h1 style="font-family:monospace; margin-top:20%; color:lightgray; text-align:center;">
Reloading...
</h1>
'; setTimeout(function(){location.reload()},2000) diff --git a/launch.py b/launch.py index 16627a032..537670a39 100644 --- a/launch.py +++ b/launch.py @@ -9,6 +9,7 @@ import platform dir_repos = "repositories" python = sys.executable git = os.environ.get('GIT', "git") +index_url = os.environ.get('INDEX_URL', "") def extract_arg(args, name): @@ -57,7 +58,8 @@ def run_python(code, desc=None, errdesc=None): def run_pip(args, desc=None): - return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") + index_url_line = f' --index-url {index_url}' if index_url != '' else '' + return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") def check_run_python(code): @@ -76,7 +78,7 @@ def git_clone(url, dir, name, commithash=None): return run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") - run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commint for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") + run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") return run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") diff --git a/modules/deepbooru.py b/modules/deepbooru.py index f34f37882..4ad334a12 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -102,7 +102,7 @@ def get_deepbooru_tags_model(): tags = dd.project.load_tags_from_project(model_path) model = dd.project.load_model_from_project( - model_path, compile_model=True + model_path, compile_model=False ) return model, tags diff --git a/modules/devices.py b/modules/devices.py index 03ef58f19..eb4225834 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -34,7 +34,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") -device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() +device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() dtype = torch.float16 dtype_vae = torch.float16 diff --git a/modules/extras.py b/modules/extras.py index 532d869f5..f2f5a7b04 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -159,24 +159,12 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, interp_amount, save_as_half, custom_name): - # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) +def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name): def weighted_sum(theta0, theta1, theta2, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) - # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def sigmoid(theta0, theta1, theta2, alpha): - alpha = alpha * alpha * (3 - (2 * alpha)) - return theta0 + ((theta1 - theta0) * alpha) - - # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def inv_sigmoid(theta0, theta1, theta2, alpha): - import math - alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0) - return theta0 + ((theta1 - theta0) * alpha) - def add_difference(theta0, theta1, theta2, alpha): - return theta0 + (theta1 - theta2) * (1.0 - alpha) + return theta0 + (theta1 - theta2) * alpha primary_model_info 
= sd_models.checkpoints_list[primary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name] @@ -198,9 +186,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam theta_2 = None theta_funcs = { - "Weighted Sum": weighted_sum, - "Sigmoid": sigmoid, - "Inverse Sigmoid": inv_sigmoid, + "Weighted sum": weighted_sum, "Add difference": add_difference, } theta_func = theta_funcs[interp_method] @@ -209,7 +195,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: - theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key] if theta_2 else None, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + t2 = (theta_2 or {}).get(key) + if t2 is None: + t2 = torch.zeros_like(theta_0[key]) + + theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier) + if save_as_half: theta_0[key] = theta_0[key].half() @@ -222,7 +213,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' + filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' filename = filename if custom_name == '' else (custom_name + '.ckpt') output_modelname = os.path.join(ckpt_dir, filename) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index f1248bb7b..a2b3bc0af 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -5,6 +5,7 @@ import os import sys import traceback import tqdm +import csv import torch @@ -14,6 +15,7 @@ import torch from torch import einsum from einops import rearrange, repeat import modules.textual_inversion.dataset +from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -180,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): return self.to_out(out) -def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): +def stack_conds(conds): + if len(conds) == 1: + return torch.stack(conds) + + # same as in reconstruct_multicond_batch + token_count = max([x.shape[0] for x in conds]) + for i in range(len(conds)): + if conds[i].shape[0] != token_count: + last_vector = conds[i][-1:] + last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1]) + conds[i] = torch.vstack([conds[i], last_vector_repeated]) + + return torch.stack(conds) + +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert hypernetwork_name, 'hypernetwork not selected' path = shared.hypernetworks.get(hypernetwork_name, None) @@ -209,7 +225,7 @@ def 
train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) @@ -233,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, entry in pbar: + for i, entries in pbar: hypernetwork.step = i + ititial_step scheduler.apply(optimizer, hypernetwork.step) @@ -244,11 +260,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, break with torch.autocast("cuda"): - cond = entry.cond.to(devices.device) - x = entry.latent.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), cond)[0] + c = stack_conds([entry.cond for entry in entries]).to(devices.device) +# c = torch.vstack([entry.cond for entry in entries]).to(devices.device) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] del x - del cond + del c losses[hypernetwork.step % losses.shape[0]] = loss.item() @@ -262,23 +279,39 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') hypernetwork.save(last_saved_file) + textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) + if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt - optimizer.zero_grad() shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, - prompt=preview_text, - steps=20, do_not_save_grid=True, do_not_save_samples=True, ) + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + + preview_text = p.prompt + processed = processing.process_images(p) image = processed.images[0] if len(processed.images)>0 else None @@ -297,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
<p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
-Last prompt: {html.escape(entry.cond_text)}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
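For reference, the stack_conds helper added in the hypernetwork training hunk above pads every conditioning tensor in a batch to the longest token count by repeating its last vector, then stacks the result into one batch tensor (the same approach as reconstruct_multicond_batch). A minimal standalone sketch of that padding step, with purely illustrative tensor shapes:

import torch

def stack_conds(conds):
    # pad each (tokens, dim) tensor up to the longest prompt in the batch
    # by repeating its final vector, then stack into (batch, tokens, dim)
    token_count = max(c.shape[0] for c in conds)
    padded = []
    for c in conds:
        if c.shape[0] != token_count:
            last_vector_repeated = c[-1:].repeat(token_count - c.shape[0], 1)
            c = torch.vstack([c, last_vector_repeated])
        padded.append(c)
    return torch.stack(padded)

# illustrative shapes only: two prompts of 77 and 154 tokens, 768-dim embeddings
conds = [torch.randn(77, 768), torch.randn(154, 768)]
print(stack_conds(conds).shape)  # torch.Size([2, 154, 768])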
diff --git a/modules/images.py b/modules/images.py index c0a906762..b9589563a 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,4 +1,5 @@ import datetime +import io import math import os from collections import namedtuple @@ -23,6 +24,10 @@ def image_grid(imgs, batch_size=1, rows=None): rows = opts.n_rows elif opts.n_rows == 0: rows = batch_size + elif opts.grid_prevent_empty_spots: + rows = math.floor(math.sqrt(len(imgs))) + while len(imgs) % rows != 0: + rows -= 1 else: rows = math.sqrt(len(imgs)) rows = round(rows) @@ -463,3 +468,22 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i txt_fullfn = None return fullfn, txt_fullfn + + +def image_data(data): + try: + image = Image.open(io.BytesIO(data)) + textinfo = image.text["parameters"] + return textinfo, None + except Exception: + pass + + try: + text = data.decode('utf8') + assert len(text) < 10000 + return text, None + + except Exception: + pass + + return '', None diff --git a/modules/images_history.py b/modules/images_history.py new file mode 100644 index 000000000..9260df8a8 --- /dev/null +++ b/modules/images_history.py @@ -0,0 +1,181 @@ +import os +import shutil + + +def traverse_all_files(output_dir, image_list, curr_dir=None): + curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir) + try: + f_list = os.listdir(curr_path) + except: + if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt": + image_list.append(curr_dir) + return image_list + for file in f_list: + file = file if curr_dir is None else os.path.join(curr_dir, file) + file_path = os.path.join(curr_path, file) + if file[-4:] == ".txt": + pass + elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0: + image_list.append(file) + else: + image_list = traverse_all_files(output_dir, image_list, file) + return image_list + + +def get_recent_images(dir_name, page_index, step, image_index, tabname): + page_index = int(page_index) + f_list = os.listdir(dir_name) + image_list = [] + image_list = traverse_all_files(dir_name, image_list) + image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) + num = 48 if tabname != "extras" else 12 + max_page_index = len(image_list) // num + 1 + page_index = max_page_index if page_index == -1 else page_index + step + page_index = 1 if page_index < 1 else page_index + page_index = max_page_index if page_index > max_page_index else page_index + idx_frm = (page_index - 1) * num + image_list = image_list[idx_frm:idx_frm + num] + image_index = int(image_index) + if image_index < 0 or image_index > len(image_list) - 1: + current_file = None + hidden = None + else: + current_file = image_list[int(image_index)] + hidden = os.path.join(dir_name, current_file) + return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, "" + + +def first_page_click(dir_name, page_index, image_index, tabname): + return get_recent_images(dir_name, 1, 0, image_index, tabname) + + +def end_page_click(dir_name, page_index, image_index, tabname): + return get_recent_images(dir_name, -1, 0, image_index, tabname) + + +def prev_page_click(dir_name, page_index, image_index, tabname): + return get_recent_images(dir_name, page_index, -1, image_index, tabname) + + +def next_page_click(dir_name, page_index, image_index, tabname): + return get_recent_images(dir_name, page_index, 1, image_index, tabname) + + +def page_index_change(dir_name, page_index, image_index, tabname): + return get_recent_images(dir_name, 
page_index, 0, image_index, tabname) + + +def show_image_info(num, image_path, filenames): + # print(f"select image {num}") + file = filenames[int(num)] + return file, num, os.path.join(image_path, file) + + +def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index): + if name == "": + return filenames, delete_num + else: + delete_num = int(delete_num) + index = list(filenames).index(name) + i = 0 + new_file_list = [] + for name in filenames: + if i >= index and i < index + delete_num: + path = os.path.join(dir_name, name) + if os.path.exists(path): + print(f"Delete file {path}") + os.remove(path) + txt_file = os.path.splitext(path)[0] + ".txt" + if os.path.exists(txt_file): + os.remove(txt_file) + else: + print(f"Not exists file {path}") + else: + new_file_list.append(name) + i += 1 + return new_file_list, 1 + + +def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): + if opts.outdir_samples != "": + dir_name = opts.outdir_samples + elif tabname == "txt2img": + dir_name = opts.outdir_txt2img_samples + elif tabname == "img2img": + dir_name = opts.outdir_img2img_samples + elif tabname == "extras": + dir_name = opts.outdir_extras_samples + d = dir_name.split("/") + dir_name = "/" if dir_name.startswith("/") else d[0] + for p in d[1:]: + dir_name = os.path.join(dir_name, p) + with gr.Row(): + renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page") + first_page = gr.Button('First Page') + prev_page = gr.Button('Prev Page') + page_index = gr.Number(value=1, label="Page Index") + next_page = gr.Button('Next Page') + end_page = gr.Button('End Page') + with gr.Row(elem_id=tabname + "_images_history"): + with gr.Row(): + with gr.Column(scale=2): + history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) + with gr.Row(): + delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") + delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") + with gr.Column(): + with gr.Row(): + pnginfo_send_to_txt2img = gr.Button('Send to txt2img') + pnginfo_send_to_img2img = gr.Button('Send to img2img') + with gr.Row(): + with gr.Column(): + img_file_info = gr.Textbox(label="Generate Info", interactive=False) + img_file_name = gr.Textbox(label="File Name", interactive=False) + with gr.Row(): + # hiden items + + img_path = gr.Textbox(dir_name.rstrip("/"), visible=False) + tabname_box = gr.Textbox(tabname, visible=False) + image_index = gr.Textbox(value=-1, visible=False) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) + filenames = gr.State() + hidden = gr.Image(type="pil", visible=False) + info1 = gr.Textbox(visible=False) + info2 = gr.Textbox(visible=False) + + # turn pages + gallery_inputs = [img_path, page_index, image_index, tabname_box] + gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name] + + first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, 
outputs=gallery_outputs) + renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) + # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index]) + + # other funcitons + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden]) + img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) + delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) + hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) + + # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) + switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') + switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') + + +def create_history_tabs(gr, opts, run_pnginfo, switch_dict): + with gr.Blocks(analytics_enabled=False) as images_history: + with gr.Tabs() as tabs: + with gr.Tab("txt2img history"): + with gr.Blocks(analytics_enabled=False) as images_history_txt2img: + show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict) + with gr.Tab("img2img history"): + with gr.Blocks(analytics_enabled=False) as images_history_img2img: + show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict) + with gr.Tab("extras history"): + with gr.Blocks(analytics_enabled=False) as images_history_img2img: + show_images_history(gr, opts, "extras", run_pnginfo, switch_dict) + return images_history diff --git a/modules/interrogate.py b/modules/interrogate.py index af858cc09..9263d65a6 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -55,7 +55,7 @@ class InterrogateModels: model, preprocess = clip.load(clip_model_name) model.eval() - model = model.to(shared.device) + model = model.to(devices.device_interrogate) return model, preprocess @@ -65,14 +65,14 @@ class InterrogateModels: if not shared.cmd_opts.no_half: self.blip_model = self.blip_model.half() - self.blip_model = self.blip_model.to(shared.device) + self.blip_model = self.blip_model.to(devices.device_interrogate) if self.clip_model is None: self.clip_model, self.clip_preprocess = self.load_clip_model() if not shared.cmd_opts.no_half: self.clip_model = self.clip_model.half() - self.clip_model = self.clip_model.to(shared.device) + self.clip_model = self.clip_model.to(devices.device_interrogate) self.dtype = next(self.clip_model.parameters()).dtype @@ -99,11 +99,11 @@ class InterrogateModels: text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] top_count = min(top_count, len(text_array)) - text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device) + text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate) text_features = self.clip_model.encode_text(text_tokens).type(self.dtype) text_features /= text_features.norm(dim=-1, keepdim=True) - similarity = torch.zeros((1, len(text_array))).to(shared.device) + similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate) for i in range(image_features.shape[0]): similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) similarity /= 
image_features.shape[0] @@ -116,7 +116,7 @@ class InterrogateModels: transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) + ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate) with torch.no_grad(): caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length) @@ -140,7 +140,7 @@ class InterrogateModels: res = caption - clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) + clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate) precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext with torch.no_grad(), precision_scope("cuda"): diff --git a/modules/processing.py b/modules/processing.py index ab68d63a4..1db26c3ea 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -145,9 +145,8 @@ class Processed: self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] - self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) - self.subseed = int( - self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 + self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1 + self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.all_prompts = all_prompts or [self.prompt] self.all_seeds = all_seeds or [self.seed] @@ -541,16 +540,15 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - firstphase_width = 0 - firstphase_height = 0 - firstphase_width_truncated = 0 - firstphase_height_truncated = 0 - def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs): + def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr - self.scale_latent = scale_latent self.denoising_strength = denoising_strength + self.firstphase_width = firstphase_width + self.firstphase_height = firstphase_height + self.truncate_x = 0 + self.truncate_y = 0 def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: @@ -559,14 +557,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - desired_pixel_count = 512 * 512 - actual_pixel_count = self.width * self.height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) + if self.firstphase_width == 0 or self.firstphase_height == 0: + desired_pixel_count = 512 * 512 + actual_pixel_count = self.width * self.height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + self.firstphase_width = math.ceil(scale * self.width / 64) * 64 + self.firstphase_height = math.ceil(scale * self.height / 64) * 64 + firstphase_width_truncated = int(scale 
* self.width) + firstphase_height_truncated = int(scale * self.height) + + else: + + width_ratio = self.width / self.firstphase_width + height_ratio = self.height / self.firstphase_height + + if width_ratio > height_ratio: + firstphase_width_truncated = self.firstphase_width + firstphase_height_truncated = self.firstphase_width * self.height / self.width + else: + firstphase_width_truncated = self.firstphase_height * self.width / self.height + firstphase_height_truncated = self.firstphase_height + + self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" + self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f + self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - self.firstphase_width = math.ceil(scale * self.width / 64) * 64 - self.firstphase_height = math.ceil(scale * self.height / 64) * 64 - self.firstphase_width_truncated = int(scale * self.width) - self.firstphase_height_truncated = int(scale * self.height) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) @@ -585,37 +600,27 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) - truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f - truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f + samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] - samples = samples[:, :, truncate_y // 2:samples.shape[2] - truncate_y // 2, - truncate_x // 2:samples.shape[3] - truncate_x // 2] - - if self.scale_latent: - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), - mode="bilinear") + if opts.use_scale_latent_for_hires_fix: + samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") else: decoded_samples = decode_first_stage(self.sd_model, samples) + lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) - if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": - decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), - mode="bilinear") - else: - lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) + batch_images = [] + for i, x_sample in enumerate(lowres_samples): + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + image = Image.fromarray(x_sample) + image = images.resize_image(0, image, self.width, self.height) + image = np.array(image).astype(np.float32) / 255.0 + image = np.moveaxis(image, 2, 0) + batch_images.append(image) - batch_images = [] - for i, x_sample in enumerate(lowres_samples): - x_sample = 255. 
* np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - image = Image.fromarray(x_sample) - image = images.resize_image(0, image, self.width, self.height) - image = np.array(image).astype(np.float32) / 255.0 - image = np.moveaxis(image, 2, 0) - batch_images.append(image) - - decoded_samples = torch.from_numpy(np.array(batch_images)) - decoded_samples = decoded_samples.to(shared.device) - decoded_samples = 2. * decoded_samples - 1. + decoded_samples = torch.from_numpy(np.array(batch_images)) + decoded_samples = decoded_samples.to(shared.device) + decoded_samples = 2. * decoded_samples - 1. samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) diff --git a/modules/safe.py b/modules/safe.py index 20be16a50..399165a19 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -96,11 +96,18 @@ def load(filename, *args, **kwargs): if not shared.cmd_opts.disable_safe_unpickle: check_pt(filename) + except pickle.UnpicklingError: + print(f"Error verifying pickled file from {filename}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) + print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) + return None + except Exception: print(f"Error verifying pickled file from {filename}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) - print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr) + print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) return None return unsafe_torch_load(filename, *args, **kwargs) diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a55b4c32..3aa21ec14 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,4 +1,4 @@ -import glob +import collections import os.path import sys from collections import namedtuple @@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir)) CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) checkpoints_list = {} +checkpoints_loaded = collections.OrderedDict() try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. 
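The load_model_weights hunk below turns checkpoints_loaded (the collections.OrderedDict introduced above) into a small LRU cache: a hit calls move_to_end and reuses the stored state dict, a miss loads from disk and stores it, and popitem(last=False) evicts the least recently used entry once the cache grows past shared.opts.sd_checkpoint_cache. A minimal sketch of that pattern in isolation, where MAX_ENTRIES and load_state_dict_from_disk are illustrative stand-ins:

import collections

checkpoints_loaded = collections.OrderedDict()
MAX_ENTRIES = 2  # illustrative stand-in for shared.opts.sd_checkpoint_cache

def get_cached_state_dict(checkpoint_info, load_state_dict_from_disk):
    if checkpoint_info in checkpoints_loaded:
        # cache hit: mark the entry as most recently used and reuse it
        checkpoints_loaded.move_to_end(checkpoint_info)
        return checkpoints_loaded[checkpoint_info]
    # cache miss: load from disk, store, and evict the oldest entries if over the limit
    state_dict = load_state_dict_from_disk(checkpoint_info)
    checkpoints_loaded[checkpoint_info] = state_dict
    while len(checkpoints_loaded) > MAX_ENTRIES:
        checkpoints_loaded.popitem(last=False)
    return state_dict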
@@ -132,38 +133,45 @@ def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") + if checkpoint_info not in checkpoints_loaded: + print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - pl_sd = torch.load(checkpoint_file, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") + pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") - sd = get_state_dict_from_checkpoint(pl_sd) + sd = get_state_dict_from_checkpoint(pl_sd) + model.load_state_dict(sd, strict=False) - model.load_state_dict(sd, strict=False) + if shared.cmd_opts.opt_channelslast: + model.to(memory_format=torch.channels_last) - if shared.cmd_opts.opt_channelslast: - model.to(memory_format=torch.channels_last) + if not shared.cmd_opts.no_half: + model.half() - if not shared.cmd_opts.no_half: - model.half() + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: + vae_file = shared.cmd_opts.vae_path - if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: - vae_file = shared.cmd_opts.vae_path + if os.path.exists(vae_file): + print(f"Loading VAE weights from: {vae_file}") + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + model.first_stage_model.load_state_dict(vae_dict) - if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") - vae_ckpt = torch.load(vae_file, map_location="cpu") - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + model.first_stage_model.to(devices.dtype_vae) - model.first_stage_model.load_state_dict(vae_dict) - - model.first_stage_model.to(devices.dtype_vae) + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) # LRU + else: + print(f"Loading weights [{sd_model_hash}] from cache") + checkpoints_loaded.move_to_end(checkpoint_info) + model.load_state_dict(checkpoints_loaded[checkpoint_info]) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file @@ -202,6 +210,7 @@ def reload_model_weights(sd_model, info=None): return if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + checkpoints_loaded.clear() shared.sd_model = load_model() return shared.sd_model diff --git a/modules/shared.py b/modules/shared.py index 7cd608cae..3c5ffef1e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -36,6 +36,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_ parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model 
optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") +parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") @@ -56,7 +57,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") -parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[]) +parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -78,10 +79,11 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl cmd_opts = parser.parse_args() -devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ -(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer']) +devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ +(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer']) device = devices.device +weight_load_location = None if cmd_opts.lowram else "cpu" batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram @@ -184,6 +186,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "grid_format": OptionInfo('png', 'File format for grids'), "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), + 
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), @@ -224,6 +227,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), + "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { @@ -242,11 +246,13 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), - "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), + "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), + "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), + "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), @@ -260,7 +266,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), - 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { @@ -288,6 +293,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full 
page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 59b2b0219..68ceffe3f 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -24,11 +24,12 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): - re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): + re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token + self.batch_size = batch_size self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) @@ -78,13 +79,14 @@ class PersonalizedBase(Dataset): if include_cond: entry.cond_text = self.create_text(filename_text) - entry.cond = cond_model([entry.cond_text]).to(devices.cpu) + entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) self.dataset.append(entry) - self.length = len(self.dataset) * repeats + assert len(self.dataset) > 1, "No images have been found in the dataset." + self.length = len(self.dataset) * repeats // batch_size - self.initial_indexes = np.arange(self.length) % len(self.dataset) + self.initial_indexes = np.arange(len(self.dataset)) self.indexes = None self.shuffle() @@ -101,13 +103,19 @@ class PersonalizedBase(Dataset): return self.length def __getitem__(self, i): - if i % len(self.dataset) == 0: - self.shuffle() + res = [] - index = self.indexes[i % len(self.indexes)] - entry = self.dataset[index] + for j in range(self.batch_size): + position = i * self.batch_size + j + if position % len(self.indexes) == 0: + self.shuffle() - if entry.cond is None: - entry.cond_text = self.create_text(entry.filename_text) + index = self.indexes[position % len(self.indexes)] + entry = self.dataset[index] - return entry + if entry.cond is None: + entry.cond_text = self.create_text(entry.filename_text) + + res.append(entry) + + return res diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b12a8e6d8..f59b47a97 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -6,6 +6,7 @@ import torch import tqdm import html import datetime +import csv from PIL import Image, PngImagePlugin @@ -172,15 +173,33 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def batched(dataset, total, n=1): - for ndx in range(0, total, n): - yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] +def write_loss(log_directory, filename, step, epoch_len, values): + if shared.opts.training_write_csv_every == 0: + return + + if step % shared.opts.training_write_csv_every != 0: + return + + write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True + + with open(os.path.join(log_directory, filename), "a+", newline='') as 
fout: + csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())]) + + if write_csv_header: + csv_writer.writeheader() + + epoch = step // epoch_len + epoch_step = step - epoch * epoch_len + + csv_writer.writerow({ + "step": step + 1, + "epoch": epoch + 1, + "epoch_step": epoch_step + 1, + **values, + }) -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, - create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, - preview_image_prompt, batch_size=1, - gradient_accumulation=1): +def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -212,11 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, - height=training_height, - repeats=shared.opts.training_image_repeats_per_epoch, - placeholder_token=embedding_name, model=shared.sd_model, - device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) hijack = sd_hijack.model_hijack @@ -235,8 +250,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) - pbar = tqdm.tqdm(enumerate(batched(ds, steps - ititial_step, batch_size)), total=steps - ititial_step) - for i, entry in pbar: + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + for i, entries in pbar: embedding.step = i + ititial_step scheduler.apply(optimizer, embedding.step) @@ -247,11 +262,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini break with torch.autocast("cuda"): - c = cond_model([e.cond_text for e in entry]) - - x = torch.stack([e.latent for e in entry]).to(devices.device) + c = cond_model([entry.cond_text for entry in entries]) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) loss = shared.sd_model(x, c)[0] - del x losses[embedding.step % losses.shape[0]] = loss.item() @@ -271,21 +284,37 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') embedding.save(last_saved_file) + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) + if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') - preview_text = 
entry[0].cond_text if preview_image_prompt == "" else preview_image_prompt - p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, - prompt=preview_text, - steps=20, - height=training_height, - width=training_width, do_not_save_grid=True, do_not_save_samples=True, ) + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + p.width = training_width + p.height = training_height + + preview_text = p.prompt + processed = processing.process_images(p) image = processed.images[0] @@ -320,7 +349,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini

Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
-Last prompt: {html.escape(entry[-1].cond_text)}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>

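For context on the CSV logging added in the textual_inversion hunk above: write_loss only appends a row every training_write_csv_every steps and derives the epoch counters from the raw step and the dataset length. Below is a minimal standalone sketch of that bookkeeping; the helper name write_loss_row, the "logs" path, the interval default and the values dict are illustrative stand-ins, not taken from the patch.

import csv
import os

def write_loss_row(log_directory, filename, step, epoch_len, values, every=500):
    # Mirror of the interval check: do nothing when logging is disabled
    # or the current step is not a multiple of the logging interval.
    if every == 0 or step % every != 0:
        return

    path = os.path.join(log_directory, filename)
    write_header = not os.path.exists(path)

    with open(path, "a+", newline="") as fout:
        writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *values.keys()])
        if write_header:
            writer.writeheader()

        # epoch_len is the number of images seen per epoch; the CSV stores 1-based counters.
        epoch = step // epoch_len
        epoch_step = step - epoch * epoch_len
        writer.writerow({"step": step + 1, "epoch": epoch + 1, "epoch_step": epoch_step + 1, **values})

if __name__ == "__main__":
    os.makedirs("logs", exist_ok=True)
    # With a 100-image dataset, step 500 is recorded as step 501, epoch 6, epoch_step 1.
    write_loss_row("logs", "textual_inversion_loss.csv", 500, 100, {"loss": "0.1234567", "learn_rate": 0.005})

Running this twice appends a second data row but writes the header only once, which is the same append-mode behaviour the os.path.exists check in the patch relies on.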
diff --git a/modules/txt2img.py b/modules/txt2img.py index eedcdfe02..8f394d05b 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -6,18 +6,13 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, - restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, - subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, - height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, - aesthetic_lr=0, +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int,aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False, - *args): + aesthetic_text_negative=False, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -41,8 +36,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: restore_faces=restore_faces, tiling=tiling, enable_hr=enable_hr, - scale_latent=scale_latent if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None, + firstphase_width=firstphase_width if enable_hr else None, + firstphase_height=firstphase_height if enable_hr else None, ) if cmd_opts.enable_console_prompts: diff --git a/modules/ui.py b/modules/ui.py index e98e21135..d0696101e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -7,6 +7,7 @@ import mimetypes import os import random import sys +import tempfile import time import traceback import platform @@ -22,7 +23,7 @@ import gradio as gr import gradio.utils import gradio.routes -from modules import sd_hijack +from modules import sd_hijack, sd_models from modules.paths import script_path from modules.shared import opts, cmd_opts,aesthetic_embeddings @@ -41,7 +42,10 @@ from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui + import modules.aesthetic_clip +import modules.images_history as img_his + # this is a fix for Windows users. 
Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -81,6 +85,8 @@ art_symbol = '\U0001f3a8' # 🎨 paste_symbol = '\u2199\ufe0f' # ↙ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 def plaintext_to_html(text): @@ -89,6 +95,14 @@ def plaintext_to_html(text): def image_from_url_text(filedata): + if type(filedata) == dict and filedata["is_file"]: + filename = filedata["name"] + tempdir = os.path.normpath(tempfile.gettempdir()) + normfn = os.path.normpath(filename) + assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory' + + return Image.open(filename) + if type(filedata) == list: if len(filedata) == 0: return None @@ -177,6 +191,23 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") +def save_pil_to_file(pil_image, dir=None): + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in pil_image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + + file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) + pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) + return file_obj + + +# override save to file function so that it also writes PNG info +gr.processing_utils.save_pil_to_file = save_pil_to_file + + def wrap_gradio_call(func, extra_outputs=None): def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled @@ -305,7 +336,7 @@ def visit(x, func, path=""): def add_style(name: str, prompt: str, negative_prompt: str): if name is None: - return [gr_show(), gr_show()] + return [gr_show() for x in range(4)] style = modules.styles.PromptStyle(name, prompt, negative_prompt) shared.prompt_styles.styles[style.name] = style @@ -430,30 +461,38 @@ def create_toprow(is_img2img): id_part = "img2img" if is_img2img else "txt2img" with gr.Row(elem_id="toprow"): - with gr.Column(scale=4): + with gr.Column(scale=6): with gr.Row(): with gr.Column(scale=80): with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2) - - with gr.Column(scale=1, elem_id="roll_col"): - roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) - paste = gr.Button(value=paste_symbol, elem_id="paste") - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - with gr.Column(scale=10, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, + placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) with gr.Row(): - with gr.Column(scale=8): + with gr.Column(scale=80): with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) - with gr.Column(scale=1, elem_id="roll_col"): - sh = gr.Button(elem_id="sh", visible=True) + 
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) + with gr.Column(scale=1, elem_id="roll_col"): + roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) + paste = gr.Button(value=paste_symbol, elem_id="paste") + save_style = gr.Button(value=save_style_symbol, elem_id="style_create") + prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") + + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + + button_interrogate = None + button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_id="interrogate_col"): + button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + + if cmd_opts.deepdanbooru: + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") with gr.Column(scale=1): with gr.Row(): @@ -473,20 +512,14 @@ def create_toprow(is_img2img): outputs=[], ) - with gr.Row(scale=1): - if is_img2img: - interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - if cmd_opts.deepdanbooru: - deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - else: - deepbooru = None - else: - interrogate = None - deepbooru = None - prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") - save_style = gr.Button('Create style', elem_id="style_create") + with gr.Row(): + with gr.Column(scale=1, elem_id="style_pos_col"): + prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + with gr.Column(scale=1, elem_id="style_neg_col"): + prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -510,13 +543,40 @@ def setup_progressbar(progressbar, preview, id_part, textinfo=None): ) +def apply_setting(key, value): + if value is None: + return gr.update() + + if key == "sd_model_checkpoint": + ckpt_info = sd_models.get_closet_checkpoint_match(value) + + if ckpt_info is not None: + value = ckpt_info.title + else: + return gr.update() + + comp_args = opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data[key] + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + 
return value + + def create_ui(wrap_gradio_gpu_call): import modules.img2img import modules.txt2img with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) with gr.Row(elem_id='txt2img_progress_row'): with gr.Column(scale=1): @@ -554,10 +614,11 @@ def create_ui(wrap_gradio_gpu_call): enable_hr = gr.Checkbox(label='Highres. fix', value=False) with gr.Row(visible=False) as hr_options: - scale_latent = gr.Checkbox(label='Scale latent', value=False) + firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0) + firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0) denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) - with gr.Row(): + with gr.Row(equal_height=True): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) @@ -600,33 +661,35 @@ def create_ui(wrap_gradio_gpu_call): fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), _js="submit", inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - scale_latent, - denoising_strength, - aesthetic_lr, - aesthetic_weight, - aesthetic_steps, - aesthetic_imgs, - aesthetic_slerp, - aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative - ] + custom_inputs, + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + denoising_strength, + firstphase_width, + firstphase_height, + aesthetic_lr, + aesthetic_weight, + aesthetic_steps, + aesthetic_imgs, + aesthetic_slerp, + aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative + ] + custom_inputs, + outputs=[ txt2img_gallery, generation_info, @@ -638,6 +701,17 @@ def create_ui(wrap_gradio_gpu_call): txt2img_prompt.submit(**txt2img_args) submit.click(**txt2img_args) + txt_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ] + ) + enable_hr.change( fn=lambda x: gr_show(x), inputs=[enable_hr], @@ -690,14 +764,29 @@ def create_ui(wrap_gradio_gpu_call): (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (firstphase_width, "First pass size-1"), + (firstphase_height, "First pass size-2"), ] - 
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) + + txt2img_preview_params = [ + txt2img_prompt, + txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + with gr.Column(scale=1): pass @@ -711,10 +800,10 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: with gr.TabItem('img2img', id='img2img'): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool) + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480) with gr.TabItem('Inpaint', id='inpaint'): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA") + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") @@ -792,6 +881,17 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + img2img_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + img2img_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ] + ) + mask_mode.change( lambda mode, img: { init_img_with_mask: gr_show(mode == 0), @@ -932,7 +1032,6 @@ def create_ui(wrap_gradio_gpu_call): (seed_resize_from_h, "Seed resize from-2"), (denoising_strength, "Denoising strength"), ] - modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as extras_interface: @@ -980,6 +1079,7 @@ def create_ui(wrap_gradio_gpu_call): button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else '' open_extras_folder = gr.Button('Open output directory', elem_id=button_id) + submit.click( 
fn=wrap_gradio_gpu_call(modules.extras.run_extras), _js="get_extras_tab_index", @@ -1039,6 +1139,14 @@ def create_ui(wrap_gradio_gpu_call): inputs=[image], outputs=[html, generation_info, html2], ) + #images history + images_history_switch_dict = { + "fn":modules.generation_parameters_copypaste.connect_paste, + "t2i":txt2img_paste_fields, + "i2i":img2img_paste_fields + } + + images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): @@ -1050,8 +1158,8 @@ def create_ui(wrap_gradio_gpu_call): secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") custom_name = gr.Textbox(label="Custom Name (Optional)") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation amount (1 - M)', value=0.3) - interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid", "Add difference"], value="Weighted Sum", label="Interpolation Method") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Save as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') @@ -1125,6 +1233,7 @@ def create_ui(wrap_gradio_gpu_call): train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") + batch_size = gr.Number(label='Batch size', value=1, precision=0) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) @@ -1137,7 +1246,7 @@ def create_ui(wrap_gradio_gpu_call): create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) - preview_image_prompt = gr.Textbox(label='Preview prompt', value="") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False) with gr.Row(): interrupt_training = gr.Button(value="Interrupt") @@ -1220,6 +1329,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ train_embedding_name, learn_rate, + batch_size, dataset_directory, log_directory, training_width, @@ -1229,9 +1339,8 @@ def create_ui(wrap_gradio_gpu_call): save_embedding_every, template_file, save_image_with_stored_embedding, - preview_image_prompt, - batch_size, - gradient_accumulation + preview_from_txt2img, + *txt2img_preview_params, ], outputs=[ ti_output, @@ -1245,13 +1354,15 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ train_hypernetwork_name, learn_rate, + batch_size, dataset_directory, log_directory, steps, create_image_every, save_embedding_every, template_file, - preview_image_prompt, + preview_from_txt2img, + *txt2img_preview_params, ], outputs=[ ti_output, @@ -1463,6 +1574,7 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), + (images_history, "History", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), (settings_interface, "Settings", "settings"), @@ -1603,8 +1715,22 @@ Requested path was: {f} outputs=[extras_image], ) - modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields, generation_info, 'switch_to_txt2img') - modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields, generation_info, 'switch_to_img2img_img2img') + settings_map = { + 'sd_hypernetwork': 'Hypernet', + 'CLIP_stop_at_last_layers': 'Clip skip', + 'sd_model_checkpoint': 'Model hash', + } + + settings_paste_fields = [ + (component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None))) + for k, v in settings_map.items() + ] + + modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt) + modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt) + + modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img') + modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img') ui_config_file = cmd_opts.ui_config_file ui_settings = {} @@ -1686,3 +1812,4 @@ if 'gradio_routes_templates_response' not in globals(): gradio_routes_templates_response = gradio.routes.templates.TemplateResponse gradio.routes.templates.TemplateResponse = template_response + diff --git a/requirements.txt b/requirements.txt index a0d985ce7..cf583de99 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ fairscale==0.4.4 fonts font-roboto gfpgan -gradio==3.4.1 +gradio==3.5 invisible-watermark numpy omegaconf diff --git a/requirements_versions.txt b/requirements_versions.txt index 2bbea40b4..abadcb583 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -2,7 +2,7 @@ transformers==4.19.2 diffusers==0.3.0 basicsr==1.4.2 gfpgan==1.3.8 -gradio==3.4.1 +gradio==3.5 numpy==1.23.3 Pillow==9.2.0 realesrgan==0.3.0 diff --git a/script.js b/script.js index 9543cbe68..88f2c839d 100644 --- a/script.js +++ b/script.js @@ -50,9 +50,9 @@ document.addEventListener("DOMContentLoaded", function() { document.addEventListener('keydown', 
function(e) { var handled = false; if (e.key !== undefined) { - if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true; + if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; } else if (e.keyCode !== undefined) { - if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true; + if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; } if (handled) { button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index b24f1a806..1266be6fa 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -1,7 +1,9 @@ +import copy import math import os import sys import traceback +import shlex import modules.scripts as scripts import gradio as gr @@ -10,6 +12,75 @@ from modules.processing import Processed, process_images from PIL import Image from modules.shared import opts, cmd_opts, state + +def process_string_tag(tag): + return tag + + +def process_int_tag(tag): + return int(tag) + + +def process_float_tag(tag): + return float(tag) + + +def process_boolean_tag(tag): + return True if (tag == "true") else False + + +prompt_tags = { + "sd_model": None, + "outpath_samples": process_string_tag, + "outpath_grids": process_string_tag, + "prompt_for_display": process_string_tag, + "prompt": process_string_tag, + "negative_prompt": process_string_tag, + "styles": process_string_tag, + "seed": process_int_tag, + "subseed_strength": process_float_tag, + "subseed": process_int_tag, + "seed_resize_from_h": process_int_tag, + "seed_resize_from_w": process_int_tag, + "sampler_index": process_int_tag, + "batch_size": process_int_tag, + "n_iter": process_int_tag, + "steps": process_int_tag, + "cfg_scale": process_float_tag, + "width": process_int_tag, + "height": process_int_tag, + "restore_faces": process_boolean_tag, + "tiling": process_boolean_tag, + "do_not_save_samples": process_boolean_tag, + "do_not_save_grid": process_boolean_tag +} + + +def cmdargs(line): + args = shlex.split(line) + pos = 0 + res = {} + + while pos < len(args): + arg = args[pos] + + assert arg.startswith("--"), f'must start with "--": {arg}' + tag = arg[2:] + + func = prompt_tags.get(tag, None) + assert func, f'unknown commandline option: {arg}' + + assert pos+1 < len(args), f'missing argument for command line option {arg}' + + val = args[pos+1] + + res[tag] = func(val) + + pos += 2 + + return res + + class Script(scripts.Script): def title(self): return "Prompts from file or textbox" @@ -32,26 +103,48 @@ class Script(scripts.Script): return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ] def run(self, p, checkbox_txt, data: bytes, prompt_txt: str): - if (checkbox_txt): + if checkbox_txt: lines = [x.strip() for x in prompt_txt.splitlines()] else: lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")] lines = [x for x in lines if len(x) > 0] - img_count = len(lines) * p.n_iter - batch_count = math.ceil(img_count / p.batch_size) - loop_count = math.ceil(batch_count / p.n_iter) - print(f"Will process {img_count} images in {batch_count} batches.") - p.do_not_save_grid = True - state.job_count = batch_count + job_count = 0 + jobs = [] + + for line in lines: + if "--" in line: + try: + args = cmdargs(line) + except Exception: + print(f"Error parsing line {line} as commandline:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + args = {"prompt": line} + else:
+ args = {"prompt": line} + + n_iter = args.get("n_iter", 1) + if n_iter != 1: + job_count += n_iter + else: + job_count += 1 + + jobs.append(args) + + print(f"Will process {len(lines)} lines in {job_count} jobs.") + state.job_count = job_count images = [] - for loop_no in range(loop_count): - state.job = f"{loop_no + 1} out of {loop_count}" - p.prompt = lines[loop_no*p.batch_size:(loop_no+1)*p.batch_size] * p.n_iter - proc = process_images(p) + for n, args in enumerate(jobs): + state.job = f"{state.job_no + 1} out of {state.job_count}" + + copy_p = copy.copy(p) + for k, v in args.items(): + setattr(copy_p, k, v) + + proc = process_images(copy_p) images += proc.images return Processed(p, images, p.seed, "") diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index efb63af54..88ad3bf78 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -12,7 +12,7 @@ import gradio as gr from modules import images from modules.hypernetworks import hypernetwork -from modules.processing import process_images, Processed, get_correct_sampler +from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.sd_samplers @@ -176,7 +176,7 @@ axis_options = [ AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None), AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), - AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones + AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), ] @@ -338,7 +338,7 @@ class Script(scripts.Script): ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label == 'Seed': + if axis_opt.label in ['Seed','Var. seed']: return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] else: return axis_list @@ -354,6 +354,9 @@ class Script(scripts.Script): else: total_steps = p.steps * len(xs) * len(ys) + if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: + total_steps *= 2 + print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. 
(Total steps to process: {total_steps * p.n_iter})") shared.total_tqdm.updateTotal(total_steps * p.n_iter) diff --git a/style.css b/style.css index aa3d379c1..b534f9500 100644 --- a/style.css +++ b/style.css @@ -115,7 +115,7 @@ padding: 0.4em 0; } -#roll, #paste{ +#roll, #paste, #style_create, #style_apply{ min-width: 2em; min-height: 2em; max-width: 2em; @@ -126,14 +126,14 @@ margin: 0.1em 0; } -#style_apply, #style_create, #interrogate{ - margin: 0.75em 0.25em 0.25em 0.25em; - min-width: 5em; +#interrogate_col{ + min-width: 0 !important; + max-width: 8em !important; } - -#style_apply, #style_create, #deepbooru{ - margin: 0.75em 0.25em 0.25em 0.25em; - min-width: 5em; +#interrogate, #deepbooru{ + margin: 0em 0.25em 0.9em 0.25em; + min-width: 8em; + max-width: 8em; } #style_pos_col, #style_neg_col{ @@ -167,18 +167,6 @@ button{ align-self: stretch !important; } -#prompt, #negative_prompt{ - border: none !important; -} -#prompt textarea, #negative_prompt textarea{ - border: none !important; -} - - -#img2maskimg .h-60{ - height: 30rem; -} - .overflow-hidden, .gr-panel{ overflow: visible !important; } @@ -451,10 +439,6 @@ input[type="range"]{ --tw-bg-opacity: 0 !important; } -#img2img_image div.h-60{ - height: 480px; -} - #context-menu{ z-index:9999; position:absolute; @@ -529,3 +513,11 @@ canvas[key="mask"] { .row.gr-compact{ overflow: visible; } + +#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img, +img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img +{ + height: 480px !important; + max-height: 480px !important; + min-height: 480px !important; +} \ No newline at end of file diff --git a/webui.sh b/webui.sh index 05ca497d2..980c0aaf3 100755 --- a/webui.sh +++ b/webui.sh @@ -82,8 +82,8 @@ then clone_dir="${PWD##*/}" fi -# Check prequisites -for preq in git python3 +# Check prerequisites +for preq in "${GIT}" "${python_cmd}" do if ! hash "${preq}" &>/dev/null then