diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index be7ffa23b..9a0b8d22b 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -18,7 +18,7 @@ jobs: cache-dependency-path: | **/requirements*txt - name: Run tests - run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test + run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test - name: Upload main app stdout-stderr uses: actions/upload-artifact@v3 if: always() diff --git a/README.md b/README.md index 24f8e7998..b67e2296a 100644 --- a/README.md +++ b/README.md @@ -13,9 +13,9 @@ A browser interface based on Gradio library for Stable Diffusion. - Prompt Matrix - Stable Diffusion Upscale - Attention, specify parts of text that the model should pay more attention to - - a man in a ((tuxedo)) - will pay more attention to tuxedo - - a man in a (tuxedo:1.21) - alternative syntax - - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user) + - a man in a `((tuxedo))` - will pay more attention to tuxedo + - a man in a `(tuxedo:1.21)` - alternative syntax + - select text and press `Ctrl+Up` or `Ctrl+Down` to automatically adjust attention to selected text (code contributed by anonymous user) - Loopback, run img2img processing multiple times - X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters - Textual Inversion @@ -28,7 +28,7 @@ A browser interface based on Gradio library for Stable Diffusion. - CodeFormer, face restoration tool as an alternative to GFPGAN - RealESRGAN, neural network upscaler - ESRGAN, neural network upscaler with a lot of third party models - - SwinIR and Swin2SR([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers + - SwinIR and Swin2SR ([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers - LDSR, Latent diffusion super resolution upscaling - Resizing aspect ratio options - Sampling method selection @@ -46,7 +46,7 @@ A browser interface based on Gradio library for Stable Diffusion. - drag and drop an image/text-parameters to promptbox - Read Generation Parameters Button, loads parameters in promptbox to UI - Settings page -- Running arbitrary python code from UI (must run with --allow-code to enable) +- Running arbitrary python code from UI (must run with `--allow-code` to enable) - Mouseover hints for most UI elements - Possible to change defaults/mix/max/step values for UI elements via text config - Tiling support, a checkbox to create images that can be tiled like textures @@ -69,7 +69,7 @@ A browser interface based on Gradio library for Stable Diffusion. - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2` - No token limit for prompts (original stable diffusion lets you use up to 75 tokens) - DeepDanbooru integration, creates danbooru style tags for anime prompts -- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args) +- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add `--xformers` to commandline args) - via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI - Generate forever option - Training tab @@ -78,11 +78,11 @@ A browser interface based on Gradio library for Stable Diffusion. - Clip skip - Hypernetworks - Loras (same as Hypernetworks but more pretty) -- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt. +- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt - Can select to load a different VAE from settings screen - Estimated completion time in progress bar - API -- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. +- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions - [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions @@ -91,7 +91,6 @@ A browser interface based on Gradio library for Stable Diffusion. - Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64 - Now with a license! - Reorder elements in the UI from settings screen -- ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. @@ -101,7 +100,7 @@ Alternatively, use online services (like Google Colab): - [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services) ### Automatic Installation on Windows -1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH" +1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH". 2. Install [git](https://git-scm.com/download/win). 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`. 4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user. @@ -159,4 +158,4 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Security advice - RyotaK - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. -- (You) +- (You) \ No newline at end of file diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 8937b585e..d3eb0d3bc 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -2,20 +2,34 @@ import glob import os import re import torch +from typing import Union from modules import shared, devices, sd_models, errors metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} re_digits = re.compile(r"\d+") -re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") -re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") -re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)") -re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} -def convert_diffusers_name_to_compvis(key): - def match(match_list, regex): +def convert_diffusers_name_to_compvis(key, is_sd2): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + r = re.match(regex, key) if not r: return False @@ -26,16 +40,33 @@ def convert_diffusers_name_to_compvis(key): m = [] - if match(m, re_unet_down_blocks): - return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - if match(m, re_unet_mid_blocks): - return f"diffusion_model_middle_block_1_{m[1]}" + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" - if match(m, re_unet_up_blocks): - return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" - if match(m, re_text_block): return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" return key @@ -101,15 +132,22 @@ def load_lora(name, filename): sd = sd_models.read_state_dict(filename) - keys_failed_to_match = [] + keys_failed_to_match = {} + is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping for key_diffusers, weight in sd.items(): - fullkey = convert_diffusers_name_to_compvis(key_diffusers) - key, lora_key = fullkey.split(".", 1) + key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1) + key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2) sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + if sd_module is None: - keys_failed_to_match.append(key_diffusers) + m = re_x_proj.match(key) + if m: + sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) + + if sd_module is None: + keys_failed_to_match[key_diffusers] = key continue lora_module = lora.modules.get(key, None) @@ -123,15 +161,21 @@ def load_lora(name, filename): if type(sd_module) == torch.nn.Linear: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.MultiheadAttention: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.Conv2d: module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) else: + print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') + continue assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' with torch.no_grad(): module.weight.copy_(weight) - module.to(device=devices.device, dtype=devices.dtype) + module.to(device=devices.cpu, dtype=devices.dtype) if lora_key == "lora_up.weight": lora_module.up = module @@ -177,28 +221,120 @@ def load_loras(names, multipliers=None): loaded_loras.append(lora) -def lora_forward(module, input, res): - if len(loaded_loras) == 0: - return res +def lora_calc_updown(lora, module, target): + with torch.no_grad(): + up = module.up.weight.to(target.device, dtype=target.dtype) + down = module.down.weight.to(target.device, dtype=target.dtype) - lora_layer_name = getattr(module, 'lora_layer_name', None) - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is not None: - if shared.opts.lora_apply_to_outputs and res.shape == input.shape: - res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + else: + updown = up @ down + + updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + return updown + + +def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + """ + Applies the currently selected set of Loras to the weights of torch layer self. + If weights already have this particular set of loras applied, does nothing. + If not, restores orginal weights from backup and alters weights according to loras. + """ + + lora_layer_name = getattr(self, 'lora_layer_name', None) + if lora_layer_name is None: + return + + current_names = getattr(self, "lora_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) + + weights_backup = getattr(self, "lora_weights_backup", None) + if weights_backup is None: + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) + else: + weights_backup = self.weight.to(devices.cpu, copy=True) + + self.lora_weights_backup = weights_backup + + if current_names != wanted_names: + if weights_backup is not None: + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) else: - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + self.weight.copy_(weights_backup) - return res + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is not None and hasattr(self, 'weight'): + self.weight += lora_calc_updown(lora, module, self.weight) + continue + + module_q = lora.modules.get(lora_layer_name + "_q_proj", None) + module_k = lora.modules.get(lora_layer_name + "_k_proj", None) + module_v = lora.modules.get(lora_layer_name + "_v_proj", None) + module_out = lora.modules.get(lora_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight) + updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight) + updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + + self.in_proj_weight += updown_qkv + self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight) + continue + + if module is None: + continue + + print(f'failed to calculate lora weights for layer {lora_layer_name}') + + setattr(self, "lora_current_names", wanted_names) + + +def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) def lora_Linear_forward(self, input): - return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Linear_forward_before_lora(self, input) + + +def lora_Linear_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) + + return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) def lora_Conv2d_forward(self, input): - return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Conv2d_forward_before_lora(self, input) + + +def lora_Conv2d_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) + + return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) + + +def lora_MultiheadAttention_forward(self, *args, **kwargs): + lora_apply_weights(self) + + return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs) + + +def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) + + return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs) def list_available_loras(): @@ -211,7 +347,7 @@ def list_available_loras(): glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) - for filename in sorted(candidates): + for filename in sorted(candidates, key=str.lower): if os.path.isdir(filename): continue diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 2e860160e..0adab2254 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -9,7 +9,11 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload(): torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora + torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora + torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora def before_ui(): @@ -20,11 +24,27 @@ def before_ui(): if not hasattr(torch.nn, 'Linear_forward_before_lora'): torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward +if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'): + torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict + if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward +if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'): + torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict + +if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'): + torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward + +if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'): + torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict + torch.nn.Linear.forward = lora.lora_Linear_forward +torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict torch.nn.Conv2d.forward = lora.lora_Conv2d_forward +torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict +torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward +torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) @@ -33,6 +53,4 @@ script_callbacks.on_before_ui(before_ui) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), - "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"), - })) diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js index 4a85c8ebf..f0918e260 100644 --- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js @@ -89,22 +89,15 @@ function checkBrackets(evt, textArea, counterElt) { function setupBracketChecking(id_prompt, id_counter){ var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); var counter = gradioApp().getElementById(id_counter) + textarea.addEventListener("input", function(evt){ checkBrackets(evt, textarea, counter) }); } -var shadowRootLoaded = setInterval(function() { - var shadowRoot = document.querySelector('gradio-app').shadowRoot; - if(! shadowRoot) return false; - - var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea'); - if(shadowTextArea.length < 1) return false; - - clearInterval(shadowRootLoaded); - +onUiLoaded(function(){ setupBracketChecking('txt2img_prompt', 'txt2img_token_counter') setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter') - setupBracketChecking('img2img_prompt', 'imgimg_token_counter') + setupBracketChecking('img2img_prompt', 'img2img_token_counter') setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter') -}, 1000); +}) \ No newline at end of file diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 1bf3fc30d..ef4b613af 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,4 +1,4 @@ -
+
{metadata_button}
diff --git a/html/licenses.html b/html/licenses.html index bddbf4665..bc995aa07 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -635,4 +635,30 @@ SOFTWARE. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + +

Curated transformers

+The MPS workaround for nn.Linear on macOS 13.2.X is based on the MPS workaround for nn.Linear created by danieldk for Curated transformers +
+The MIT License (MIT)
+
+Copyright (C) 2021 ExplosionAI GmbH
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
 
\ No newline at end of file diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js index 0f164b82c..a8278cca2 100644 --- a/javascript/aspectRatioOverlay.js +++ b/javascript/aspectRatioOverlay.js @@ -12,7 +12,7 @@ function dimensionChange(e, is_width, is_height){ currentHeight = e.target.value*1.0 } - var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) + var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block"; if(!inImg2img){ return; @@ -22,7 +22,7 @@ function dimensionChange(e, is_width, is_height){ var tabIndex = get_tab_index('mode_img2img') if(tabIndex == 0){ // img2img - targetElement = gradioApp().querySelector('div[data-testid=image] img'); + targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img'); } else if(tabIndex == 1){ //Sketch targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); } else if(tabIndex == 2){ // Inpaint @@ -30,7 +30,7 @@ function dimensionChange(e, is_width, is_height){ } else if(tabIndex == 3){ // Inpaint sketch targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); } - + if(targetElement){ @@ -38,7 +38,7 @@ function dimensionChange(e, is_width, is_height){ if(!arPreviewRect){ arPreviewRect = document.createElement('div') arPreviewRect.id = "imageARPreview"; - gradioApp().getRootNode().appendChild(arPreviewRect) + gradioApp().appendChild(arPreviewRect) } @@ -91,23 +91,26 @@ onUiUpdate(function(){ if(arPreviewRect){ arPreviewRect.style.display = 'none'; } - var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) - if(inImg2img){ - let inputs = gradioApp().querySelectorAll('input'); - inputs.forEach(function(e){ - var is_width = e.parentElement.id == "img2img_width" - var is_height = e.parentElement.id == "img2img_height" + var tabImg2img = gradioApp().querySelector("#tab_img2img"); + if (tabImg2img) { + var inImg2img = tabImg2img.style.display == "block"; + if(inImg2img){ + let inputs = gradioApp().querySelectorAll('input'); + inputs.forEach(function(e){ + var is_width = e.parentElement.id == "img2img_width" + var is_height = e.parentElement.id == "img2img_height" - if((is_width || is_height) && !e.classList.contains('scrollwatch')){ - e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} ) - e.classList.add('scrollwatch') - } - if(is_width){ - currentWidth = e.value*1.0 - } - if(is_height){ - currentHeight = e.value*1.0 - } - }) - } + if((is_width || is_height) && !e.classList.contains('scrollwatch')){ + e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} ) + e.classList.add('scrollwatch') + } + if(is_width){ + currentWidth = e.value*1.0 + } + if(is_height){ + currentHeight = e.value*1.0 + } + }) + } + } }); diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index 11bcce1bc..06f505b0d 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -43,7 +43,7 @@ contextMenuInit = function(){ }) - gradioApp().getRootNode().appendChild(contextMenu) + gradioApp().appendChild(contextMenu) let menuWidth = contextMenu.offsetWidth + 4; let menuHeight = contextMenu.offsetHeight + 4; diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index 619bb1fa3..20a5aadfb 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -1,6 +1,6 @@ function keyupEditAttention(event){ let target = event.originalTarget || event.composedPath()[0]; - if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return; + if (! target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return; if (! (event.metaKey || event.ctrlKey)) return; let isPlus = event.key == "ArrowUp" diff --git a/javascript/extensions.js b/javascript/extensions.js index c593cd2e5..72924a28c 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -1,5 +1,5 @@ -function extensions_apply(_, _){ +function extensions_apply(_, _, disable_all){ var disable = [] var update = [] @@ -13,10 +13,10 @@ function extensions_apply(_, _){ restart_reload() - return [JSON.stringify(disable), JSON.stringify(update)] + return [JSON.stringify(disable), JSON.stringify(update), disable_all] } -function extensions_check(){ +function extensions_check(_, _){ var disable = [] gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 2fb87cd5b..253221389 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -139,3 +139,41 @@ function extraNetworksShowMetadata(text){ popup(elem); } + +function requestGet(url, data, handler, errorHandler){ + var xhr = new XMLHttpRequest(); + var args = Object.keys(data).map(function(k){ return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]) }).join('&') + xhr.open("GET", url + "?" + args, true); + + xhr.onreadystatechange = function () { + if (xhr.readyState === 4) { + if (xhr.status === 200) { + try { + var js = JSON.parse(xhr.responseText); + handler(js) + } catch (error) { + console.error(error); + errorHandler() + } + } else{ + errorHandler() + } + } + }; + var js = JSON.stringify(data); + xhr.send(js); +} + +function extraNetworksRequestMetadata(event, extraPage, cardName){ + showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); } + + requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){ + if(data && data.metadata){ + extraNetworksShowMetadata(data.metadata) + } else{ + showError() + } + }, showError) + + event.stopPropagation() +} diff --git a/javascript/hints.js b/javascript/hints.js index 7f4101b23..f48a0eb69 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -18,11 +18,10 @@ titles = { "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u{1f4c2}": "Open images output directory", "\u{1f4be}": "Save style", - "\u{1f5d1}": "Clear prompt", + "\u{1f5d1}\ufe0f": "Clear prompt", "\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4d2}": "Paste available values into the field", - "\u{1f3b4}": "Show extra networks", - + "\u{1f3b4}": "Show/hide extra networks", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", @@ -40,8 +39,7 @@ titles = { "Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image", "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.", - "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.", - + "Skip": "Stop processing current image and continue processing.", "Interrupt": "Stop processing images and return any results accumulated so far.", "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.", @@ -71,8 +69,10 @@ titles = { "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime], [datetime