Merge branch 'release_candidate'

This commit is contained in:
AUTOMATIC 2023-05-13 08:16:37 +03:00
commit b08500cec8
71 changed files with 865 additions and 422 deletions

View File

@ -1,3 +1,48 @@
## 1.2.0
### Features:
* do not wait for stable diffusion model to load at startup
* add filename patterns: [denoising]
* directory hiding for extra networks: dirs starting with . will hide their cards on extra network tabs unless specifically searched for
* Lora: for the `<...>` text in prompt, use name of Lora that is in the metdata of the file, if present, instead of filename (both can be used to activate lora)
* Lora: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
* Lora: Fix some Loras not working (ones that have 3x3 convolution layer)
* Lora: add an option to use old method of applying loras (producing same results as with kohya-ss)
* add version to infotext, footer and console output when starting
* add links to wiki for filename pattern settings
* add extended info for quicksettings setting and use multiselect input instead of a text field
### Minor:
* gradio bumped to 3.29.0
* torch bumped to 2.0.1
* --subpath option for gradio for use with reverse proxy
* linux/OSX: use existing virtualenv if already active (the VIRTUAL_ENV environment variable)
* possible frontend optimization: do not apply localizations if there are none
* Add extra `None` option for VAE in XYZ plot
* print error to console when batch processing in img2img fails
* create HTML for extra network pages only on demand
* allow directories starting with . to still list their models for lora, checkpoints, etc
* put infotext options into their own category in settings tab
* do not show licenses page when user selects Show all pages in settings
### Extensions:
* Tooltip localization support
* Add api method to get LoRA models with prompt
### Bug Fixes:
* re-add /docs endpoint
* fix gamepad navigation
* make the lightbox fullscreen image function properly
* fix squished thumbnails in extras tab
* keep "search" filter for extra networks when user refreshes the tab (previously it showed everthing after you refreshed)
* fix webui showing the same image if you configure the generation to always save results into same file
* fix bug with upscalers not working properly
* Fix MPS on PyTorch 2.0.1, Intel Macs
* make it so that custom context menu from contextMenu.js only disappears after user's click, ignoring non-user click events
* prevent Reload UI button/link from reloading the page when it's not yet ready
* fix prompts from file script failing to read contents from a drag/drop file
## 1.1.1 ## 1.1.1
### Bug Fixes: ### Bug Fixes:
* fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle * fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle

View File

@ -1,6 +1,7 @@
from modules import extra_networks, shared from modules import extra_networks, shared
import lora import lora
class ExtraNetworkLora(extra_networks.ExtraNetwork): class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self): def __init__(self):
super().__init__('lora') super().__init__('lora')

View File

@ -4,7 +4,7 @@ import re
import torch import torch
from typing import Union from typing import Union
from modules import shared, devices, sd_models, errors from modules import shared, devices, sd_models, errors, scripts
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@ -93,6 +93,7 @@ class LoraOnDisk:
self.metadata = m self.metadata = m
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
self.alias = self.metadata.get('ss_output_name', self.name)
class LoraModule: class LoraModule:
@ -165,8 +166,10 @@ def load_lora(name, filename):
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.MultiheadAttention: elif type(sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d: elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
else: else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue continue
@ -199,11 +202,11 @@ def load_loras(names, multipliers=None):
loaded_loras.clear() loaded_loras.clear()
loras_on_disk = [available_loras.get(name, None) for name in names] loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
if any([x is None for x in loras_on_disk]): if any([x is None for x in loras_on_disk]):
list_available_loras() list_available_loras()
loras_on_disk = [available_loras.get(name, None) for name in names] loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
for i, name in enumerate(names): for i, name in enumerate(names):
lora = already_loaded.get(name, None) lora = already_loaded.get(name, None)
@ -232,6 +235,8 @@ def lora_calc_updown(lora, module, target):
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else: else:
updown = up @ down updown = up @ down
@ -240,6 +245,19 @@ def lora_calc_updown(lora, module, target):
return updown return updown
def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "lora_weights_backup", None)
if weights_backup is None:
return
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
self.weight.copy_(weights_backup)
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
""" """
Applies the currently selected set of Loras to the weights of torch layer self. Applies the currently selected set of Loras to the weights of torch layer self.
@ -264,12 +282,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
self.lora_weights_backup = weights_backup self.lora_weights_backup = weights_backup
if current_names != wanted_names: if current_names != wanted_names:
if weights_backup is not None: lora_restore_weights_from_backup(self)
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
self.weight.copy_(weights_backup)
for lora in loaded_loras: for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None) module = lora.modules.get(lora_layer_name, None)
@ -300,12 +313,45 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
setattr(self, "lora_current_names", wanted_names) setattr(self, "lora_current_names", wanted_names)
def lora_forward(module, input, original_forward):
"""
Old way of applying Lora by executing operations during layer's forward.
Stacking many loras this way results in big performance degradation.
"""
if len(loaded_loras) == 0:
return original_forward(module, input)
input = devices.cond_cast_unet(input)
lora_restore_weights_from_backup(module)
lora_reset_cached_weight(module)
res = original_forward(module, input)
lora_layer_name = getattr(module, 'lora_layer_name', None)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is None:
continue
module.up.to(device=devices.device)
module.down.to(device=devices.device)
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return res
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
setattr(self, "lora_current_names", ()) setattr(self, "lora_current_names", ())
setattr(self, "lora_weights_backup", None) setattr(self, "lora_weights_backup", None)
def lora_Linear_forward(self, input): def lora_Linear_forward(self, input):
if shared.opts.lora_functional:
return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
lora_apply_weights(self) lora_apply_weights(self)
return torch.nn.Linear_forward_before_lora(self, input) return torch.nn.Linear_forward_before_lora(self, input)
@ -318,6 +364,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
def lora_Conv2d_forward(self, input): def lora_Conv2d_forward(self, input):
if shared.opts.lora_functional:
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
lora_apply_weights(self) lora_apply_weights(self)
return torch.nn.Conv2d_forward_before_lora(self, input) return torch.nn.Conv2d_forward_before_lora(self, input)
@ -343,24 +392,59 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
def list_available_loras(): def list_available_loras():
available_loras.clear() available_loras.clear()
available_lora_aliases.clear()
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
candidates = \ candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
for filename in sorted(candidates, key=str.lower): for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename): if os.path.isdir(filename):
continue continue
name = os.path.splitext(os.path.basename(filename))[0] name = os.path.splitext(os.path.basename(filename))[0]
entry = LoraOnDisk(name, filename)
available_loras[name] = LoraOnDisk(name, filename) available_loras[name] = entry
available_lora_aliases[name] = entry
available_lora_aliases[entry.alias] = entry
re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
def infotext_pasted(infotext, params):
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
return # if the other extension is active, it will handle those fields, no need to do anything
added = []
for k, v in params.items():
if not k.startswith("AddNet Model "):
continue
num = k[13:]
if params.get("AddNet Module " + num) != "LoRA":
continue
name = params.get("AddNet Model " + num)
if name is None:
continue
m = re_lora_name.match(name)
if m:
name = m.group(1)
multiplier = params.get("AddNet Weight A " + num, "1.0")
added.append(f"<lora:{name}:{multiplier}>")
if added:
params["Prompt"] += "\n" + "".join(added)
available_loras = {} available_loras = {}
available_lora_aliases = {}
loaded_loras = [] loaded_loras = []
list_available_loras() list_available_loras()

View File

@ -1,12 +1,12 @@
import torch import torch
import gradio as gr import gradio as gr
from fastapi import FastAPI
import lora import lora
import extra_networks_lora import extra_networks_lora
import ui_extra_networks_lora import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload(): def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
@ -49,8 +49,33 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload) script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui) script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
})) }))
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
"lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))
def create_lora_json(obj: lora.LoraOnDisk):
return {
"name": obj.name,
"alias": obj.alias,
"path": obj.filename,
"metadata": obj.metadata,
}
def api_loras(_: gr.Blocks, app: FastAPI):
@app.get("/sdapi/v1/loras")
async def get_loras():
return [create_lora_json(obj) for obj in lora.available_loras.values()]
script_callbacks.on_app_started(api_loras)

View File

@ -21,7 +21,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
"preview": self.find_preview(path), "preview": self.find_preview(path),
"description": self.find_description(path), "description": self.find_description(path),
"search_term": self.search_terms_from_path(lora_on_disk.filename), "search_term": self.search_terms_from_path(lora_on_disk.filename),
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"), "prompt": json.dumps(f"<lora:{lora_on_disk.alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.{shared.opts.samples_format}", "local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
} }

View File

@ -6,7 +6,7 @@
<ul> <ul>
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a> <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
</ul> </ul>
<span style="display:none" class='search_term'>{search_term}</span> <span style="display:none" class='search_term{serach_only}'>{search_term}</span>
</div> </div>
<span class='name'>{name}</span> <span class='name'>{name}</span>
<span class='description'>{description}</span> <span class='description'>{description}</span>

View File

@ -45,29 +45,24 @@ function dimensionChange(e, is_width, is_height){
var viewportOffset = targetElement.getBoundingClientRect(); var viewportOffset = targetElement.getBoundingClientRect();
viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight ) var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
scaledx = targetElement.naturalWidth*viewportscale var scaledx = targetElement.naturalWidth*viewportscale
scaledy = targetElement.naturalHeight*viewportscale var scaledy = targetElement.naturalHeight*viewportscale
cleintRectTop = (viewportOffset.top+window.scrollY) var cleintRectTop = (viewportOffset.top+window.scrollY)
cleintRectLeft = (viewportOffset.left+window.scrollX) var cleintRectLeft = (viewportOffset.left+window.scrollX)
cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2) var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2) var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
viewRectTop = cleintRectCentreY-(scaledy/2) var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
viewRectLeft = cleintRectCentreX-(scaledx/2) var arscaledx = currentWidth*arscale
arRectWidth = scaledx var arscaledy = currentHeight*arscale
arRectHeight = scaledy
arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight ) var arRectTop = cleintRectCentreY-(arscaledy/2)
arscaledx = currentWidth*arscale var arRectLeft = cleintRectCentreX-(arscaledx/2)
arscaledy = currentHeight*arscale var arRectWidth = arscaledx
var arRectHeight = arscaledy
arRectTop = cleintRectCentreY-(arscaledy/2)
arRectLeft = cleintRectCentreX-(arscaledx/2)
arRectWidth = arscaledx
arRectHeight = arscaledy
arPreviewRect.style.top = arRectTop+'px'; arPreviewRect.style.top = arRectTop+'px';
arPreviewRect.style.left = arRectLeft+'px'; arPreviewRect.style.left = arRectLeft+'px';

View File

@ -4,7 +4,7 @@ contextMenuInit = function(){
let menuSpecs = new Map(); let menuSpecs = new Map();
const uid = function(){ const uid = function(){
return Date.now().toString(36) + Math.random().toString(36).substr(2); return Date.now().toString(36) + Math.random().toString(36).substring(2);
} }
function showContextMenu(event,element,menuEntries){ function showContextMenu(event,element,menuEntries){
@ -16,8 +16,7 @@ contextMenuInit = function(){
oldMenu.remove() oldMenu.remove()
} }
let tabButton = uiCurrentTab let baseStyle = window.getComputedStyle(uiCurrentTab)
let baseStyle = window.getComputedStyle(tabButton)
const contextMenu = document.createElement('nav') const contextMenu = document.createElement('nav')
contextMenu.id = "context-menu" contextMenu.id = "context-menu"
@ -36,7 +35,7 @@ contextMenuInit = function(){
menuEntries.forEach(function(entry){ menuEntries.forEach(function(entry){
let contextMenuEntry = document.createElement('a') let contextMenuEntry = document.createElement('a')
contextMenuEntry.innerHTML = entry['name'] contextMenuEntry.innerHTML = entry['name']
contextMenuEntry.addEventListener("click", function(e) { contextMenuEntry.addEventListener("click", function() {
entry['func'](); entry['func']();
}) })
contextMenuList.append(contextMenuEntry); contextMenuList.append(contextMenuEntry);
@ -63,7 +62,7 @@ contextMenuInit = function(){
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){ function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
currentItems = menuSpecs.get(targetElementSelector) var currentItems = menuSpecs.get(targetElementSelector)
if(!currentItems){ if(!currentItems){
currentItems = [] currentItems = []
@ -79,7 +78,7 @@ contextMenuInit = function(){
} }
function removeContextMenuOption(uid){ function removeContextMenuOption(uid){
menuSpecs.forEach(function(v,k) { menuSpecs.forEach(function(v) {
let index = -1 let index = -1
v.forEach(function(e,ei){if(e['id']==uid){index=ei}}) v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
if(index>=0){ if(index>=0){
@ -93,8 +92,7 @@ contextMenuInit = function(){
return; return;
} }
gradioApp().addEventListener("click", function(e) { gradioApp().addEventListener("click", function(e) {
let source = e.composedPath()[0] if(! e.isTrusted){
if(source.id && source.id.indexOf('check_progress')>-1){
return return
} }
@ -112,7 +110,6 @@ contextMenuInit = function(){
if(e.composedPath()[0].matches(k)){ if(e.composedPath()[0].matches(k)){
showContextMenu(e,e.composedPath()[0],v) showContextMenu(e,e.composedPath()[0],v)
e.preventDefault() e.preventDefault()
return
} }
}) })
}); });

View File

@ -69,8 +69,8 @@ function keyupEditAttention(event){
event.preventDefault(); event.preventDefault();
closeCharacter = ')' var closeCharacter = ')'
delta = opts.keyedit_precision_attention var delta = opts.keyedit_precision_attention
if (selectionStart > 0 && text[selectionStart - 1] == '<'){ if (selectionStart > 0 && text[selectionStart - 1] == '<'){
closeCharacter = '>' closeCharacter = '>'
@ -91,8 +91,8 @@ function keyupEditAttention(event){
selectionEnd += 1; selectionEnd += 1;
} }
end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return; if (isNaN(weight)) return;
weight += isPlus ? delta : -delta; weight += isPlus ? delta : -delta;

View File

@ -1,14 +1,14 @@
function extensions_apply(_, _, disable_all){ function extensions_apply(_disabled_list, _update_list, disable_all){
var disable = [] var disable = []
var update = [] var update = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked) if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substr(7)) disable.push(x.name.substring(7))
if(x.name.startsWith("update_") && x.checked) if(x.name.startsWith("update_") && x.checked)
update.push(x.name.substr(7)) update.push(x.name.substring(7))
}) })
restart_reload() restart_reload()
@ -16,12 +16,12 @@ function extensions_apply(_, _, disable_all){
return [JSON.stringify(disable), JSON.stringify(update), disable_all] return [JSON.stringify(disable), JSON.stringify(update), disable_all]
} }
function extensions_check(_, _){ function extensions_check(){
var disable = [] var disable = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked) if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substr(7)) disable.push(x.name.substring(7))
}) })
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
@ -41,7 +41,7 @@ function install_extension_from_index(button, url){
button.disabled = "disabled" button.disabled = "disabled"
button.value = "Installing..." button.value = "Installing..."
textarea = gradioApp().querySelector('#extension_to_install textarea') var textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url textarea.value = url
updateInput(textarea) updateInput(textarea)

View File

@ -1,4 +1,3 @@
function setupExtraNetworksForTab(tabname){ function setupExtraNetworksForTab(tabname){
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks') gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
@ -10,16 +9,34 @@ function setupExtraNetworksForTab(tabname){
tabs.appendChild(search) tabs.appendChild(search)
tabs.appendChild(refresh) tabs.appendChild(refresh)
search.addEventListener("input", function(evt){ var applyFilter = function(){
searchTerm = search.value.toLowerCase() var searchTerm = search.value.toLowerCase()
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){ gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase() var searchOnly = elem.querySelector('.search_only')
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : "" var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
var visible = text.indexOf(searchTerm) != -1
if(searchOnly && searchTerm.length < 4){
visible = false
}
elem.style.display = visible ? "" : "none"
}) })
}); }
search.addEventListener("input", applyFilter);
applyFilter();
extraNetworksApplyFilter[tabname] = applyFilter;
} }
function applyExtraNetworkFilter(tabname){
setTimeout(extraNetworksApplyFilter[tabname], 1);
}
var extraNetworksApplyFilter = {}
var activePromptTextarea = {}; var activePromptTextarea = {};
function setupExtraNetworks(){ function setupExtraNetworks(){
@ -55,7 +72,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text){
var partToSearch = m[1] var partToSearch = m[1]
var replaced = false var replaced = false
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){ var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
m = found.match(re_extranet); m = found.match(re_extranet);
if(m[1] == partToSearch){ if(m[1] == partToSearch){
replaced = true; replaced = true;
@ -96,9 +113,9 @@ function saveCardPreview(event, tabname, filename){
} }
function extraNetworksSearchButton(tabs_id, event){ function extraNetworksSearchButton(tabs_id, event){
searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea') var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
button = event.target var button = event.target
text = button.classList.contains("search-all") ? "" : button.textContent.trim() var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
searchTextarea.value = text searchTextarea.value = text
updateInput(searchTextarea) updateInput(searchTextarea)
@ -133,7 +150,7 @@ function popup(contents){
} }
function extraNetworksShowMetadata(text){ function extraNetworksShowMetadata(text){
elem = document.createElement('pre') var elem = document.createElement('pre')
elem.classList.add('popup-metadata'); elem.classList.add('popup-metadata');
elem.textContent = text; elem.textContent = text;
@ -165,7 +182,7 @@ function requestGet(url, data, handler, errorHandler){
} }
function extraNetworksRequestMetadata(event, extraPage, cardName){ function extraNetworksRequestMetadata(event, extraPage, cardName){
showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); } var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){ requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
if(data && data.metadata){ if(data && data.metadata){

View File

@ -23,7 +23,7 @@ let modalObserver = new MutationObserver(function(mutations) {
}); });
function attachGalleryListeners(tab_name) { function attachGalleryListeners(tab_name) {
gallery = gradioApp().querySelector('#'+tab_name+'_gallery') var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click()); gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
gallery?.addEventListener('keydown', (e) => { gallery?.addEventListener('keydown', (e) => {
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow

View File

@ -66,8 +66,8 @@ titles = {
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.", "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.", "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.", "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle", "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.", "Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
@ -118,16 +118,18 @@ titles = {
onUiUpdate(function(){ onUiUpdate(function(){
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){ gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
tooltip = titles[span.textContent]; if (span.title) return; // already has a title
if(!tooltip){ let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
tooltip = titles[span.value];
if(!tooltip){
tooltip = localization[titles[span.value]] || titles[span.value];
} }
if(!tooltip){ if(!tooltip){
for (const c of span.classList) { for (const c of span.classList) {
if (c in titles) { if (c in titles) {
tooltip = titles[c]; tooltip = localization[titles[c]] || titles[c];
break; break;
} }
} }
@ -142,7 +144,7 @@ onUiUpdate(function(){
if (select.onchange != null) return; if (select.onchange != null) return;
select.onchange = function(){ select.onchange = function(){
select.title = titles[select.value] || ""; select.title = localization[titles[select.value]] || titles[select.value] || "";
} }
}) })
}) })

View File

@ -1,16 +1,12 @@
function setInactive(elem, inactive){
if(inactive){
elem.classList.add('inactive')
} else{
elem.classList.remove('inactive')
}
}
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){ function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale') function setInactive(elem, inactive){
hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x') elem.classList.toggle('inactive', !!inactive)
hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y') }
var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "" gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""

View File

@ -2,11 +2,10 @@
* temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668 * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
* @see https://github.com/gradio-app/gradio/issues/1721 * @see https://github.com/gradio-app/gradio/issues/1721
*/ */
window.addEventListener( 'resize', () => imageMaskResize());
function imageMaskResize() { function imageMaskResize() {
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas'); const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
if ( ! canvases.length ) { if ( ! canvases.length ) {
canvases_fixed = false; canvases_fixed = false; // TODO: this is unused..?
window.removeEventListener( 'resize', imageMaskResize ); window.removeEventListener( 'resize', imageMaskResize );
return; return;
} }
@ -15,7 +14,7 @@ function imageMaskResize() {
const previewImage = wrapper.previousElementSibling; const previewImage = wrapper.previousElementSibling;
if ( ! previewImage.complete ) { if ( ! previewImage.complete ) {
previewImage.addEventListener( 'load', () => imageMaskResize()); previewImage.addEventListener( 'load', imageMaskResize);
return; return;
} }
@ -24,7 +23,6 @@ function imageMaskResize() {
const nw = previewImage.naturalWidth; const nw = previewImage.naturalWidth;
const nh = previewImage.naturalHeight; const nh = previewImage.naturalHeight;
const portrait = nh > nw; const portrait = nh > nw;
const factor = portrait;
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw); const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh); const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
@ -40,6 +38,7 @@ function imageMaskResize() {
c.style.maxHeight = '100%'; c.style.maxHeight = '100%';
c.style.objectFit = 'contain'; c.style.objectFit = 'contain';
}); });
} }
onUiUpdate(() => imageMaskResize()); onUiUpdate(imageMaskResize);
window.addEventListener( 'resize', imageMaskResize);

View File

@ -1,7 +1,6 @@
window.onload = (function(){ window.onload = (function(){
window.addEventListener('drop', e => { window.addEventListener('drop', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
const idx = selected_gallery_index();
if (target.placeholder.indexOf("Prompt") == -1) return; if (target.placeholder.indexOf("Prompt") == -1) return;
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";

View File

@ -57,7 +57,7 @@ function modalImageSwitch(offset) {
}) })
if (result != -1) { if (result != -1) {
nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)] var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
nextButton.click() nextButton.click()
const modalImage = gradioApp().getElementById("modalImage"); const modalImage = gradioApp().getElementById("modalImage");
const modal = gradioApp().getElementById("lightboxModal"); const modal = gradioApp().getElementById("lightboxModal");
@ -144,15 +144,11 @@ function setupImageForLightbox(e) {
} }
function modalZoomSet(modalImage, enable) { function modalZoomSet(modalImage, enable) {
if (enable) { if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
modalImage.classList.add('modalImageFullscreen');
} else {
modalImage.classList.remove('modalImageFullscreen');
}
} }
function modalZoomToggle(event) { function modalZoomToggle(event) {
modalImage = gradioApp().getElementById("modalImage"); var modalImage = gradioApp().getElementById("modalImage");
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen')) modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
event.stopPropagation() event.stopPropagation()
} }
@ -179,7 +175,7 @@ function galleryImageHandler(e) {
} }
onUiUpdate(function() { onUiUpdate(function() {
fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img') var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
if (fullImg_preview != null) { if (fullImg_preview != null) {
fullImg_preview.forEach(setupImageForLightbox); fullImg_preview.forEach(setupImageForLightbox);
} }

View File

@ -1,36 +1,57 @@
let delay = 350//ms window.addEventListener('gamepadconnected', (e) => {
window.addEventListener('gamepadconnected', (e) => { const index = e.gamepad.index;
console.log("Gamepad connected!") let isWaiting = false;
const gamepad = e.gamepad; setInterval(async () => {
setInterval(() => { if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
const xValue = gamepad.axes[0].toFixed(2); const gamepad = navigator.getGamepads()[index];
if (xValue < -0.3) { const xValue = gamepad.axes[0];
modalPrevImage(e); if (xValue <= -0.3) {
} else if (xValue > 0.3) {
modalNextImage(e);
}
}, delay);
});
/*
Primarily for vr controller type pointer devices.
I use the wheel event because there's currently no way to do it properly with web xr.
*/
let isScrolling = false;
window.addEventListener('wheel', (e) => {
if (isScrolling) return;
isScrolling = true;
if (e.deltaX <= -0.6) {
modalPrevImage(e); modalPrevImage(e);
} else if (e.deltaX >= 0.6) { isWaiting = true;
} else if (xValue >= 0.3) {
modalNextImage(e); modalNextImage(e);
isWaiting = true;
} }
if (isWaiting) {
await sleepUntil(() => {
const xValue = navigator.getGamepads()[index].axes[0]
if (xValue < 0.3 && xValue > -0.3) {
return true;
}
}, opts.js_modal_lightbox_gamepad_repeat);
isWaiting = false;
}
}, 10);
});
setTimeout(() => { /*
isScrolling = false; Primarily for vr controller type pointer devices.
}, delay); I use the wheel event because there's currently no way to do it properly with web xr.
*/
let isScrolling = false;
window.addEventListener('wheel', (e) => {
if (!opts.js_modal_lightbox_gamepad || isScrolling) return;
isScrolling = true;
if (e.deltaX <= -0.6) {
modalPrevImage(e);
} else if (e.deltaX >= 0.6) {
modalNextImage(e);
}
setTimeout(() => {
isScrolling = false;
}, opts.js_modal_lightbox_gamepad_repeat);
});
function sleepUntil(f, timeout) {
return new Promise((resolve) => {
const timeStart = new Date();
const wait = setInterval(function() {
if (f() || new Date() - timeStart > timeout) {
clearInterval(wait);
resolve();
}
}, 20);
}); });
}

View File

@ -25,6 +25,10 @@ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
original_lines = {} original_lines = {}
translated_lines = {} translated_lines = {}
function hasLocalization() {
return window.localization && Object.keys(window.localization).length > 0;
}
function textNodesUnder(el){ function textNodesUnder(el){
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false); var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
while(n=walk.nextNode()) a.push(n); while(n=walk.nextNode()) a.push(n);
@ -35,11 +39,11 @@ function canBeTranslated(node, text){
if(! text) return false; if(! text) return false;
if(! node.parentElement) return false; if(! node.parentElement) return false;
parentType = node.parentElement.nodeName var parentType = node.parentElement.nodeName
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false; if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
if (parentType=='OPTION' || parentType=='SPAN'){ if (parentType=='OPTION' || parentType=='SPAN'){
pnode = node var pnode = node
for(var level=0; level<4; level++){ for(var level=0; level<4; level++){
pnode = pnode.parentElement pnode = pnode.parentElement
if(! pnode) break; if(! pnode) break;
@ -69,7 +73,7 @@ function getTranslation(text){
} }
function processTextNode(node){ function processTextNode(node){
text = node.textContent.trim() var text = node.textContent.trim()
if(! canBeTranslated(node, text)) return if(! canBeTranslated(node, text)) return
@ -105,7 +109,7 @@ function processNode(node){
} }
function dumpTranslations(){ function dumpTranslations(){
dumped = {} var dumped = {}
if (localization.rtl) { if (localization.rtl) {
dumped.rtl = true dumped.rtl = true
} }
@ -119,39 +123,8 @@ function dumpTranslations(){
return dumped return dumped
} }
onUiUpdate(function(m){
m.forEach(function(mutation){
mutation.addedNodes.forEach(function(node){
processNode(node)
})
});
})
document.addEventListener("DOMContentLoaded", function() {
processNode(gradioApp())
if (localization.rtl) { // if the language is from right to left,
(new MutationObserver((mutations, observer) => { // wait for the style to load
mutations.forEach(mutation => {
mutation.addedNodes.forEach(node => {
if (node.tagName === 'STYLE') {
observer.disconnect();
for (const x of node.sheet.rules) { // find all rtl media rules
if (Array.from(x.media || []).includes('rtl')) {
x.media.appendMedium('all'); // enable them
}
}
}
})
});
})).observe(gradioApp(), { childList: true });
}
})
function download_localization() { function download_localization() {
text = JSON.stringify(dumpTranslations(), null, 4) var text = JSON.stringify(dumpTranslations(), null, 4)
var element = document.createElement('a'); var element = document.createElement('a');
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text)); element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
@ -163,3 +136,36 @@ function download_localization() {
document.body.removeChild(element); document.body.removeChild(element);
} }
if(hasLocalization()) {
onUiUpdate(function (m) {
m.forEach(function (mutation) {
mutation.addedNodes.forEach(function (node) {
processNode(node)
})
});
})
document.addEventListener("DOMContentLoaded", function () {
processNode(gradioApp())
if (localization.rtl) { // if the language is from right to left,
(new MutationObserver((mutations, observer) => { // wait for the style to load
mutations.forEach(mutation => {
mutation.addedNodes.forEach(node => {
if (node.tagName === 'STYLE') {
observer.disconnect();
for (const x of node.sheet.rules) { // find all rtl media rules
if (Array.from(x.media || []).includes('rtl')) {
x.media.appendMedium('all'); // enable them
}
}
}
})
});
})).observe(gradioApp(), { childList: true });
}
})
}

View File

@ -2,15 +2,15 @@
let lastHeadImg = null; let lastHeadImg = null;
notificationButton = null let notificationButton = null;
onUiUpdate(function(){ onUiUpdate(function(){
if(notificationButton == null){ if(notificationButton == null){
notificationButton = gradioApp().getElementById('request_notifications') notificationButton = gradioApp().getElementById('request_notifications')
if(notificationButton != null){ if(notificationButton != null){
notificationButton.addEventListener('click', function (evt) { notificationButton.addEventListener('click', () => {
Notification.requestPermission(); void Notification.requestPermission();
},true); },true);
} }
} }

View File

@ -1,16 +1,15 @@
// code related to showing and updating progressbar shown as the image is being made // code related to showing and updating progressbar shown as the image is being made
function rememberGallerySelection(id_gallery){ function rememberGallerySelection(){
} }
function getGallerySelectedIndex(id_gallery){ function getGallerySelectedIndex(){
} }
function request(url, data, handler, errorHandler){ function request(url, data, handler, errorHandler){
var xhr = new XMLHttpRequest(); var xhr = new XMLHttpRequest();
var url = url;
xhr.open("POST", url, true); xhr.open("POST", url, true);
xhr.setRequestHeader("Content-Type", "application/json"); xhr.setRequestHeader("Content-Type", "application/json");
xhr.onreadystatechange = function () { xhr.onreadystatechange = function () {
@ -107,7 +106,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
divProgress.style.width = rect.width + "px"; divProgress.style.width = rect.width + "px";
} }
progressText = "" let progressText = ""
divInner.style.width = ((res.progress || 0) * 100.0) + '%' divInner.style.width = ((res.progress || 0) * 100.0) + '%'
divInner.style.background = res.progress ? "" : "transparent" divInner.style.background = res.progress ? "" : "transparent"

View File

@ -1,7 +1,7 @@
// various functions for interaction with ui.py not large enough to warrant putting them in separate files // various functions for interaction with ui.py not large enough to warrant putting them in separate files
function set_theme(theme){ function set_theme(theme){
gradioURL = window.location.href var gradioURL = window.location.href
if (!gradioURL.includes('?__theme=')) { if (!gradioURL.includes('?__theme=')) {
window.location.replace(gradioURL + '?__theme=' + theme); window.location.replace(gradioURL + '?__theme=' + theme);
} }
@ -47,7 +47,7 @@ function extract_image_from_gallery(gallery){
return [gallery[0]]; return [gallery[0]];
} }
index = selected_gallery_index() var index = selected_gallery_index()
if (index < 0 || index >= gallery.length){ if (index < 0 || index >= gallery.length){
// Use the first image in the gallery as the default // Use the first image in the gallery as the default
@ -58,7 +58,7 @@ function extract_image_from_gallery(gallery){
} }
function args_to_array(args){ function args_to_array(args){
res = [] var res = []
for(var i=0;i<args.length;i++){ for(var i=0;i<args.length;i++){
res.push(args[i]) res.push(args[i])
} }
@ -138,7 +138,7 @@ function get_img2img_tab_index() {
} }
function create_submit_args(args){ function create_submit_args(args){
res = [] var res = []
for(var i=0;i<args.length;i++){ for(var i=0;i<args.length;i++){
res.push(args[i]) res.push(args[i])
} }
@ -160,7 +160,7 @@ function showSubmitButtons(tabname, show){
} }
function showRestoreProgressButton(tabname, show){ function showRestoreProgressButton(tabname, show){
button = gradioApp().getElementById(tabname + "_restore_progress") var button = gradioApp().getElementById(tabname + "_restore_progress")
if(! button) return if(! button) return
button.style.display = show ? "flex" : "none" button.style.display = show ? "flex" : "none"
@ -207,8 +207,9 @@ function submit_img2img(){
return res return res
} }
function restoreProgressTxt2img(x){ function restoreProgressTxt2img(){
showRestoreProgressButton("txt2img", false) showRestoreProgressButton("txt2img", false)
var id = localStorage.getItem("txt2img_task_id")
id = localStorage.getItem("txt2img_task_id") id = localStorage.getItem("txt2img_task_id")
@ -220,10 +221,11 @@ function restoreProgressTxt2img(x){
return id return id
} }
function restoreProgressImg2img(x){
function restoreProgressImg2img(){
showRestoreProgressButton("img2img", false) showRestoreProgressButton("img2img", false)
id = localStorage.getItem("img2img_task_id") var id = localStorage.getItem("img2img_task_id")
if(id) { if(id) {
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){ requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
@ -252,7 +254,7 @@ function modelmerger(){
function ask_for_style_name(_, prompt_text, negative_prompt_text) { function ask_for_style_name(_, prompt_text, negative_prompt_text) {
name_ = prompt('Style name:') var name_ = prompt('Style name:')
return [name_, prompt_text, negative_prompt_text] return [name_, prompt_text, negative_prompt_text]
} }
@ -287,11 +289,11 @@ function recalculate_prompts_img2img(){
} }
opts = {} var opts = {}
onUiUpdate(function(){ onUiUpdate(function(){
if(Object.keys(opts).length != 0) return; if(Object.keys(opts).length != 0) return;
json_elem = gradioApp().getElementById('settings_json') var json_elem = gradioApp().getElementById('settings_json')
if(json_elem == null) return; if(json_elem == null) return;
var textarea = json_elem.querySelector('textarea') var textarea = json_elem.querySelector('textarea')
@ -340,12 +342,15 @@ onUiUpdate(function(){
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button') registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button') registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
show_all_pages = gradioApp().getElementById('settings_show_all_pages') var show_all_pages = gradioApp().getElementById('settings_show_all_pages')
settings_tabs = gradioApp().querySelector('#settings div') var settings_tabs = gradioApp().querySelector('#settings div')
if(show_all_pages && settings_tabs){ if(show_all_pages && settings_tabs){
settings_tabs.appendChild(show_all_pages) settings_tabs.appendChild(show_all_pages)
show_all_pages.onclick = function(){ show_all_pages.onclick = function(){
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){ gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
if(elem.id == "settings_tab_licenses")
return;
elem.style.display = "block"; elem.style.display = "block";
}) })
} }
@ -353,9 +358,9 @@ onUiUpdate(function(){
}) })
onOptionsChanged(function(){ onOptionsChanged(function(){
elem = gradioApp().getElementById('sd_checkpoint_hash') var elem = gradioApp().getElementById('sd_checkpoint_hash')
sd_checkpoint_hash = opts.sd_checkpoint_hash || "" var sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
shorthash = sd_checkpoint_hash.substr(0,10) var shorthash = sd_checkpoint_hash.substring(0,10)
if(elem && elem.textContent != shorthash){ if(elem && elem.textContent != shorthash){
elem.textContent = shorthash elem.textContent = shorthash
@ -390,7 +395,16 @@ function update_token_counter(button_id) {
function restart_reload(){ function restart_reload(){
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>'; document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
setTimeout(function(){location.reload()},2000)
var requestPing = function(){
requestGet("./internal/ping", {}, function(data){
location.reload();
}, function(){
setTimeout(requestPing, 500);
})
}
setTimeout(requestPing, 2000);
return [] return []
} }

View File

@ -0,0 +1,41 @@
// various hints and extra info for the settings tab
onUiLoaded(function(){
createLink = function(elem_id, text, href){
var a = document.createElement('A')
a.textContent = text
a.target = '_blank';
elem = gradioApp().querySelector('#'+elem_id)
elem.insertBefore(a, elem.querySelector('label'))
return a
}
createLink("setting_samples_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
createLink("setting_directories_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
createLink("setting_quicksettings_list", "[info] ").addEventListener("click", function(event){
requestGet("./internal/quicksettings-hint", {}, function(data){
var table = document.createElement('table')
table.className = 'settings-value-table'
data.forEach(function(obj){
var tr = document.createElement('tr')
var td = document.createElement('td')
td.textContent = obj.name
tr.appendChild(td)
var td = document.createElement('td')
td.textContent = obj.label
tr.appendChild(td)
table.appendChild(tr)
})
popup(table);
})
});
})

View File

@ -19,6 +19,7 @@ python = sys.executable
git = os.environ.get('GIT', "git") git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "") index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None stored_commit_hash = None
stored_git_tag = None
dir_repos = "repositories" dir_repos = "repositories"
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ: if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
@ -70,6 +71,20 @@ def commit_hash():
return stored_commit_hash return stored_commit_hash
def git_tag():
global stored_git_tag
if stored_git_tag is not None:
return stored_git_tag
try:
stored_git_tag = run(f"{git} describe --tags").strip()
except Exception:
stored_git_tag = "<none>"
return stored_git_tag
def run(command, desc=None, errdesc=None, custom_env=None, live=False): def run(command, desc=None, errdesc=None, custom_env=None, live=False):
if desc is not None: if desc is not None:
print(desc) print(desc)
@ -222,7 +237,7 @@ def run_extensions_installers(settings_file):
def prepare_environment(): def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --extra-index-url https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu118")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17') xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
@ -246,8 +261,10 @@ def prepare_environment():
check_python_version() check_python_version()
commit = commit_hash() commit = commit_hash()
tag = git_tag()
print(f"Python {sys.version}") print(f"Python {sys.version}")
print(f"Version: {tag}")
print(f"Commit hash: {commit}") print(f"Commit hash: {commit}")
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):

View File

@ -570,20 +570,20 @@ class Api:
filename = create_embedding(**args) # create empty embedding filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end() shared.state.end()
return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename)) return CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e: except AssertionError as e:
shared.state.end() shared.state.end()
return TrainResponse(info = "create embedding error: {error}".format(error = e)) return TrainResponse(info=f"create embedding error: {e}")
def create_hypernetwork(self, args: dict): def create_hypernetwork(self, args: dict):
try: try:
shared.state.begin() shared.state.begin()
filename = create_hypernetwork(**args) # create empty embedding filename = create_hypernetwork(**args) # create empty embedding
shared.state.end() shared.state.end()
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename)) return CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e: except AssertionError as e:
shared.state.end() shared.state.end()
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e)) return TrainResponse(info=f"create hypernetwork error: {e}")
def preprocess(self, args: dict): def preprocess(self, args: dict):
try: try:
@ -593,13 +593,13 @@ class Api:
return PreprocessResponse(info = 'preprocess complete') return PreprocessResponse(info = 'preprocess complete')
except KeyError as e: except KeyError as e:
shared.state.end() shared.state.end()
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e)) return PreprocessResponse(info=f"preprocess error: invalid token: {e}")
except AssertionError as e: except AssertionError as e:
shared.state.end() shared.state.end()
return PreprocessResponse(info = "preprocess error: {error}".format(error = e)) return PreprocessResponse(info=f"preprocess error: {e}")
except FileNotFoundError as e: except FileNotFoundError as e:
shared.state.end() shared.state.end()
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e)) return PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict): def train_embedding(self, args: dict):
try: try:
@ -617,10 +617,10 @@ class Api:
if not apply_optimizations: if not apply_optimizations:
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()
shared.state.end() shared.state.end()
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError as msg: except AssertionError as msg:
shared.state.end() shared.state.end()
return TrainResponse(info = "train embedding error: {msg}".format(msg = msg)) return TrainResponse(info=f"train embedding error: {msg}")
def train_hypernetwork(self, args: dict): def train_hypernetwork(self, args: dict):
try: try:
@ -641,10 +641,10 @@ class Api:
if not apply_optimizations: if not apply_optimizations:
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()
shared.state.end() shared.state.end()
return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error)) return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError as msg: except AssertionError as msg:
shared.state.end() shared.state.end()
return TrainResponse(info="train embedding error: {error}".format(error=error)) return TrainResponse(info=f"train embedding error: {error}")
def get_memory(self): def get_memory(self):
try: try:

View File

@ -60,7 +60,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
max_debug_str_len = 131072 # (1024*1024)/8 max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr) print("Error completing request", file=sys.stderr)
argStr = f"Arguments: {str(args)} {str(kwargs)}" argStr = f"Arguments: {args} {kwargs}"
print(argStr[:max_debug_str_len], file=sys.stderr) print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len: if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
@ -73,7 +73,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
if extra_outputs_array is None: if extra_outputs_array is None:
extra_outputs_array = [None, ''] extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"] error_message = f'{type(e).__name__}: {e}'
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
shared.state.skipped = False shared.state.skipped = False
shared.state.interrupted = False shared.state.interrupted = False

View File

@ -102,3 +102,4 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')

View File

@ -156,13 +156,16 @@ class UpscalerESRGAN(Upscaler):
def load_model(self, path: str): def load_model(self, path: str):
if "http" in path: if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, filename = load_file_from_url(
file_name="%s.pth" % self.model_name, url=self.model_url,
progress=True) model_dir=self.model_path,
file_name=f"{self.model_name}.pth",
progress=True,
)
else: else:
filename = path filename = path
if not os.path.exists(filename) or filename is None: if not os.path.exists(filename) or filename is None:
print("Unable to load %s from %s" % (self.model_path, filename)) print(f"Unable to load {self.model_path} from {filename}")
return None return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

View File

@ -38,7 +38,7 @@ class RRDBNet(nn.Module):
elif upsample_mode == 'pixelshuffle': elif upsample_mode == 'pixelshuffle':
upsample_block = pixelshuffle_block upsample_block = pixelshuffle_block
else: else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
if upscale == 3: if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype) upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
else: else:
@ -261,10 +261,10 @@ class Upsample(nn.Module):
def extra_repr(self): def extra_repr(self):
if self.scale_factor is not None: if self.scale_factor is not None:
info = 'scale_factor=' + str(self.scale_factor) info = f'scale_factor={self.scale_factor}'
else: else:
info = 'size=' + str(self.size) info = f'size={self.size}'
info += ', mode=' + self.mode info += f', mode={self.mode}'
return info return info
@ -350,7 +350,7 @@ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
elif act_type == 'sigmoid': # [0, 1] range output elif act_type == 'sigmoid': # [0, 1] range output
layer = nn.Sigmoid() layer = nn.Sigmoid()
else: else:
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) raise NotImplementedError(f'activation layer [{act_type}] is not found')
return layer return layer
@ -372,7 +372,7 @@ def norm(norm_type, nc):
elif norm_type == 'none': elif norm_type == 'none':
def norm_layer(x): return Identity() def norm_layer(x): return Identity()
else: else:
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type)) raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
return layer return layer
@ -388,7 +388,7 @@ def pad(pad_type, padding):
elif pad_type == 'zero': elif pad_type == 'zero':
layer = nn.ZeroPad2d(padding) layer = nn.ZeroPad2d(padding)
else: else:
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type)) raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
return layer return layer
@ -432,7 +432,7 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
spectral_norm=False): spectral_norm=False):
""" Conv layer with padding, normalization, activation """ """ Conv layer with padding, normalization, activation """
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode) assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
padding = get_valid_padding(kernel_size, dilation) padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
padding = padding if pad_type == 'zero' else 0 padding = padding if pad_type == 'zero' else 0

View File

@ -10,7 +10,8 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
additional = shared.opts.sd_hypernetwork additional = shared.opts.sd_hypernetwork
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = [] names = []

View File

@ -59,6 +59,7 @@ def image_from_url_text(filedata):
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename) is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories' assert is_in_right_dir, 'trying to open image file outside of allowed directories'
filename = filename.rsplit('?', 1)[0]
return Image.open(filename) return Image.open(filename)
if type(filedata) == list: if type(filedata) == list:
@ -129,6 +130,7 @@ def connect_paste_params_buttons():
_js=jsfunc, _js=jsfunc,
inputs=[binding.source_image_component], inputs=[binding.source_image_component],
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
show_progress=False,
) )
if binding.source_text_component is not None and fields is not None: if binding.source_text_component is not None and fields is not None:
@ -140,6 +142,7 @@ def connect_paste_params_buttons():
fn=lambda *x: x, fn=lambda *x: x,
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names], inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
outputs=[field for field, name in fields if name in paste_field_names], outputs=[field for field, name in fields if name in paste_field_names],
show_progress=False,
) )
binding.paste_button.click( binding.paste_button.click(
@ -147,6 +150,7 @@ def connect_paste_params_buttons():
_js=f"switch_to_{binding.tabname}", _js=f"switch_to_{binding.tabname}",
inputs=None, inputs=None,
outputs=None, outputs=None,
show_progress=False,
) )
@ -265,8 +269,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
m = re_imagesize.match(v) m = re_imagesize.match(v)
if m is not None: if m is not None:
res[k+"-1"] = m.group(1) res[f"{k}-1"] = m.group(1)
res[k+"-2"] = m.group(2) res[f"{k}-2"] = m.group(2)
else: else:
res[k] = v res[k] = v
@ -409,12 +413,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
fn=paste_func, fn=paste_func,
inputs=[input_comp], inputs=[input_comp],
outputs=[x[0] for x in paste_fields], outputs=[x[0] for x in paste_fields],
show_progress=False,
) )
button.click( button.click(
fn=None, fn=None,
_js=f"recalculate_prompts_{tabname}", _js=f"recalculate_prompts_{tabname}",
inputs=[], inputs=[],
outputs=[], outputs=[],
show_progress=False,
) )

View File

@ -13,7 +13,7 @@ cache_data = None
def dump_cache(): def dump_cache():
with filelock.FileLock(cache_filename+".lock"): with filelock.FileLock(f"{cache_filename}.lock"):
with open(cache_filename, "w", encoding="utf8") as file: with open(cache_filename, "w", encoding="utf8") as file:
json.dump(cache_data, file, indent=4) json.dump(cache_data, file, indent=4)
@ -22,7 +22,7 @@ def cache(subsection):
global cache_data global cache_data
if cache_data is None: if cache_data is None:
with filelock.FileLock(cache_filename+".lock"): with filelock.FileLock(f"{cache_filename}.lock"):
if not os.path.isfile(cache_filename): if not os.path.isfile(cache_filename):
cache_data = {} cache_data = {}
else: else:

View File

@ -357,6 +357,7 @@ class FilenameGenerator:
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1, 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..] 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"], 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
} }
default_time_format = '%Y%m%d%H%M%S' default_time_format = '%Y%m%d%H%M%S'
@ -466,7 +467,7 @@ def get_next_sequence_number(path, basename):
""" """
result = -1 result = -1
if basename != '': if basename != '':
basename = basename + "-" basename = f"{basename}-"
prefix_length = len(basename) prefix_length = len(basename)
for p in os.listdir(path): for p in os.listdir(path):
@ -535,7 +536,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
add_number = opts.save_images_add_number or file_decoration == '' add_number = opts.save_images_add_number or file_decoration == ''
if file_decoration != "" and add_number: if file_decoration != "" and add_number:
file_decoration = "-" + file_decoration file_decoration = f"-{file_decoration}"
file_decoration = namegen.apply(file_decoration) + suffix file_decoration = namegen.apply(file_decoration) + suffix
@ -565,7 +566,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
def _atomically_save_image(image_to_save, filename_without_extension, extension): def _atomically_save_image(image_to_save, filename_without_extension, extension):
# save image with .tmp extension to avoid race condition when another process detects new image in the directory # save image with .tmp extension to avoid race condition when another process detects new image in the directory
temp_file_path = filename_without_extension + ".tmp" temp_file_path = f"{filename_without_extension}.tmp"
image_format = Image.registered_extensions()[extension] image_format = Image.registered_extensions()[extension]
if extension.lower() == '.png': if extension.lower() == '.png':
@ -625,7 +626,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if opts.save_txt and info is not None: if opts.save_txt and info is not None:
txt_fullfn = f"{fullfn_without_extension}.txt" txt_fullfn = f"{fullfn_without_extension}.txt"
with open(txt_fullfn, "w", encoding="utf8") as file: with open(txt_fullfn, "w", encoding="utf8") as file:
file.write(info + "\n") file.write(f"{info}\n")
else: else:
txt_fullfn = None txt_fullfn = None

View File

@ -48,7 +48,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
try: try:
img = Image.open(image) img = Image.open(image)
except UnidentifiedImageError: except UnidentifiedImageError as e:
print(e)
continue continue
# Use the EXIF orientation of photos taken by smartphones. # Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img) img = ImageOps.exif_transpose(img)

View File

@ -28,7 +28,7 @@ def category_types():
def download_default_clip_interrogate_categories(content_dir): def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...") print("Downloading CLIP categories...")
tmpdir = content_dir + "_tmp" tmpdir = f"{content_dir}_tmp"
category_types = ["artists", "flavors", "mediums", "movements"] category_types = ["artists", "flavors", "mediums", "movements"]
try: try:
@ -214,7 +214,7 @@ class InterrogateModels:
if shared.opts.interrogate_return_ranks: if shared.opts.interrogate_return_ranks:
res += f", ({match}:{score/100:.3f})" res += f", ({match}:{score/100:.3f})"
else: else:
res += ", " + match res += f", {match}"
except Exception: except Exception:
print("Error interrogating", file=sys.stderr) print("Error interrogating", file=sys.stderr)

View File

@ -54,6 +54,11 @@ if has_mps:
CondFunc('torch.cumsum', cumsum_fix_func, None) CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
if version.parse(torch.__version__) == version.parse("2.0"):
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113 # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6) CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')

View File

@ -22,9 +22,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
""" """
output = [] output = []
if ext_filter is None:
ext_filter = []
try: try:
places = [] places = []
@ -39,22 +36,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
places.append(model_path) places.append(model_path)
for place in places: for place in places:
if os.path.exists(place): for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
for file in glob.iglob(place + '**/**', recursive=True): if os.path.islink(full_path) and not os.path.exists(full_path):
full_path = file print(f"Skipping broken symlink: {full_path}")
if os.path.isdir(full_path): continue
continue if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
if os.path.islink(full_path) and not os.path.exists(full_path): continue
print(f"Skipping broken symlink: {full_path}") if full_path not in output:
continue output.append(full_path)
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
continue
if len(ext_filter) != 0:
model_name, extension = os.path.splitext(file)
if extension not in ext_filter:
continue
if file not in output:
output.append(full_path)
if model_url is not None and len(output) == 0: if model_url is not None and len(output) == 0:
if download_name is not None: if download_name is not None:
@ -133,12 +122,9 @@ forbidden_upscaler_classes = set()
def list_builtin_upscalers(): def list_builtin_upscalers():
load_upscalers()
builtin_upscaler_classes.clear() builtin_upscaler_classes.clear()
builtin_upscaler_classes.extend(Upscaler.__subclasses__()) builtin_upscaler_classes.extend(Upscaler.__subclasses__())
def forbid_loaded_nonbuiltin_upscalers(): def forbid_loaded_nonbuiltin_upscalers():
for cls in Upscaler.__subclasses__(): for cls in Upscaler.__subclasses__():
if cls not in builtin_upscaler_classes: if cls not in builtin_upscaler_classes:

View File

@ -223,7 +223,7 @@ class DDPM(pl.LightningModule):
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys:
if k.startswith(ik): if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k)) print(f"Deleting key {k} from state_dict.")
del sd[k] del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False) sd, strict=False)
@ -386,7 +386,7 @@ class DDPM(pl.LightningModule):
_, loss_dict_no_ema = self.shared_step(batch) _, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope(): with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch) _, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

View File

@ -94,7 +94,7 @@ class NoiseScheduleVP:
""" """
if schedule not in ['discrete', 'linear', 'cosine']: if schedule not in ['discrete', 'linear', 'cosine']:
raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
self.schedule = schedule self.schedule = schedule
if schedule == 'discrete': if schedule == 'discrete':
@ -469,7 +469,7 @@ class UniPC:
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
return t return t
else: else:
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
""" """

View File

@ -7,8 +7,8 @@ def connect(token, port, region):
else: else:
if ':' in token: if ':' in token:
# token = authtoken:username:password # token = authtoken:username:password
account = token.split(':')[1] + ':' + token.split(':')[-1] token, username, password = token.split(':', 2)
token = token.split(':')[0] account = f"{username}:{password}"
config = conf.PyngrokConfig( config = conf.PyngrokConfig(
auth_token=token, region=region auth_token=token, region=region

View File

@ -16,7 +16,7 @@ for possible_sd_path in possible_sd_paths:
sd_path = os.path.abspath(possible_sd_path) sd_path = os.path.abspath(possible_sd_path)
break break
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
path_dirs = [ path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []), (sd_path, 'ldm', 'Stable Diffusion', []),

View File

@ -458,6 +458,16 @@ def fix_seed(p):
p.subseed = get_fixed_seed(p.subseed) p.subseed = get_fixed_seed(p.subseed)
def program_version():
import launch
res = launch.git_tag()
if res == "<none>":
res = None
return res
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size index = position_in_batch + iteration * p.batch_size
@ -483,13 +493,14 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Init image hash": getattr(p, 'init_img_hash', None), "Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None, "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
"Version": program_version() if opts.add_version_to_infotext else None,
} }
generation_params.update(p.extra_generation_params) generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else "" negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
@ -769,7 +780,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) res = Processed(
p,
images_list=output_images,
seed=p.all_seeds[0],
info=infotext(),
comments="".join(f"\n\n{comment}" for comment in comments),
subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image,
infotexts=infotexts,
)
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess(p, res) p.scripts.postprocess(p, res)

View File

@ -96,7 +96,8 @@ def progressapi(req: ProgressRequest):
if image is not None: if image is not None:
buffered = io.BytesIO() buffered = io.BytesIO()
image.save(buffered, format="png") image.save(buffered, format="png")
live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii") base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
live_preview = f"data:image/png;base64,{base64_image}"
id_live_preview = shared.state.id_live_preview id_live_preview = shared.state.id_live_preview
else: else:
live_preview = None live_preview = None

View File

@ -28,9 +28,9 @@ class UpscalerRealESRGAN(Upscaler):
for scaler in scalers: for scaler in scalers:
if scaler.local_data_path.startswith("http"): if scaler.local_data_path.startswith("http"):
filename = modelloader.friendly_name(scaler.local_data_path) filename = modelloader.friendly_name(scaler.local_data_path)
local = next(iter([local_model for local_model in local_model_paths if local_model.endswith(filename + '.pth')]), None) local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
if local: if local_model_candidates:
scaler.local_data_path = local scaler.local_data_path = local_model_candidates[0]
if scaler.name in opts.realesrgan_enabled_models: if scaler.name in opts.realesrgan_enabled_models:
self.scalers.append(scaler) self.scalers.append(scaler)
@ -47,7 +47,7 @@ class UpscalerRealESRGAN(Upscaler):
info = self.load_model(path) info = self.load_model(path)
if not os.path.exists(info.local_data_path): if not os.path.exists(info.local_data_path):
print("Unable to load RealESRGAN model: %s" % info.name) print(f"Unable to load RealESRGAN model: {info.name}")
return img return img
upsampler = RealESRGANer( upsampler = RealESRGANer(

View File

@ -163,7 +163,8 @@ class Script:
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id""" """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False) need_tabname = self.show(True) == self.show(False)
tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else "" tabkind = 'img2img' if self.is_img2img else 'txt2txt'
tabname = f"{tabkind}_" if need_tabname else ""
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower())) title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
return f'script_{tabname}{title}_{item_id}' return f'script_{tabname}{title}_{item_id}'
@ -526,7 +527,7 @@ def add_classes_to_gradio_component(comp):
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
""" """
comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])] comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
if getattr(comp, 'multiselect', False): if getattr(comp, 'multiselect', False):
comp.elem_classes.append('multiselect') comp.elem_classes.append('multiselect')

View File

@ -75,7 +75,8 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
self.hijack.comments += hijack_comments self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0: if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
self.hijack.comments.append(f"Used embeddings: {embedding_names}")
self.hijack.fixes = hijack_fixes self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers) return self.process_tokens(remade_batch_tokens, batch_multipliers)

View File

@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
if q.device.type == 'mps':
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
dtype = q.dtype dtype = q.dtype
if shared.opts.upcast_attn: if shared.opts.upcast_attn:
q, k = q.float(), k.float() q, k = q.float(), k.float()

View File

@ -18,7 +18,7 @@ class TorchHijackForUnet:
if hasattr(torch, item): if hasattr(torch, item):
return getattr(torch, item) return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def cat(self, tensors, *args, **kwargs): def cat(self, tensors, *args, **kwargs):
if len(tensors) == 2: if len(tensors) == 2:

View File

@ -2,6 +2,8 @@ import collections
import os.path import os.path
import sys import sys
import gc import gc
import threading
import torch import torch
import re import re
import safetensors.torch import safetensors.torch
@ -45,7 +47,7 @@ class CheckpointInfo:
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename) self.hash = model_hash(filename)
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name) self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
self.shorthash = self.sha256[0:10] if self.sha256 else None self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
@ -67,7 +69,7 @@ class CheckpointInfo:
checkpoint_alisases[id] = self checkpoint_alisases[id] = self
def calculate_shorthash(self): def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name) self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
if self.sha256 is None: if self.sha256 is None:
return return
@ -404,13 +406,39 @@ def repair_config(sd_config):
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
class SdModelData:
def __init__(self):
self.sd_model = None
self.lock = threading.Lock()
def get_sd_model(self):
if self.sd_model is None:
with self.lock:
try:
load_model()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load", file=sys.stderr)
self.sd_model = None
return self.sd_model
def set_sd_model(self, v):
self.sd_model = v
model_data = SdModelData()
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
if shared.sd_model: if model_data.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model) sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
shared.sd_model = None model_data.sd_model = None
gc.collect() gc.collect()
devices.torch_gc() devices.torch_gc()
@ -464,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
timer.record("hijack") timer.record("hijack")
sd_model.eval() sd_model.eval()
shared.sd_model = sd_model model_data.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@ -484,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None):
checkpoint_info = info or select_checkpoint() checkpoint_info = info or select_checkpoint()
if not sd_model: if not sd_model:
sd_model = shared.sd_model sd_model = model_data.sd_model
if sd_model is None: # previous model load failed if sd_model is None: # previous model load failed
current_checkpoint_info = None current_checkpoint_info = None
@ -512,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None):
del sd_model del sd_model
checkpoints_loaded.clear() checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict) load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return shared.sd_model return model_data.sd_model
try: try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer) load_model_weights(sd_model, checkpoint_info, state_dict, timer)
@ -535,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None):
return sd_model return sd_model
def unload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack from modules import lowvram, devices, sd_hijack
timer = Timer() timer = Timer()
if shared.sd_model: if model_data.sd_model:
model_data.sd_model.to(devices.cpu)
# shared.sd_model.cond_stage_model.to(devices.cpu) sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
# shared.sd_model.first_stage_model.to(devices.cpu) model_data.sd_model = None
shared.sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
shared.sd_model = None
sd_model = None sd_model = None
gc.collect() gc.collect()
devices.torch_gc() devices.torch_gc()

View File

@ -111,7 +111,7 @@ def find_checkpoint_config_near_filename(info):
if info is None: if info is None:
return None return None
config = os.path.splitext(info.filename)[0] + ".yaml" config = f"{os.path.splitext(info.filename)[0]}.yaml"
if os.path.exists(config): if os.path.exists(config):
return config return config

View File

@ -198,7 +198,7 @@ class TorchHijack:
if hasattr(torch, item): if hasattr(torch, item):
return getattr(torch, item) return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x): def randn_like(self, x):
if self.sampler_noises: if self.sampler_noises:

View File

@ -89,7 +89,7 @@ def refresh_vae_list():
def find_vae_near_checkpoint(checkpoint_file): def find_vae_near_checkpoint(checkpoint_file):
checkpoint_path = os.path.splitext(checkpoint_file)[0] checkpoint_path = os.path.splitext(checkpoint_file)[0]
for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]: for vae_location in [f"{checkpoint_path}.vae.pt", f"{checkpoint_path}.vae.ckpt", f"{checkpoint_path}.vae.safetensors"]:
if os.path.isfile(vae_location): if os.path.isfile(vae_location):
return vae_location return vae_location

View File

@ -16,6 +16,7 @@ import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
from ldm.models.diffusion.ddpm import LatentDiffusion
demo = None demo = None
@ -391,21 +392,20 @@ options_templates.update(options_section(('ui', "User interface"), {
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"), "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}), "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
@ -413,6 +413,13 @@ options_templates.update(options_section(('ui', "User interface"), {
"gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}) "gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
})) }))
options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
}))
options_templates.update(options_section(('ui', "Live previews"), { options_templates.update(options_section(('ui', "Live previews"), {
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
@ -542,6 +549,10 @@ class Options:
with open(filename, "r", encoding="utf8") as file: with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file) self.data = json.load(file)
# 1.1.1 quicksettings list migration
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
bad_settings = 0 bad_settings = 0
for k, v in self.data.items(): for k, v in self.data.items():
info = self.data_labels.get(k, None) info = self.data_labels.get(k, None)
@ -600,13 +611,37 @@ class Options:
return value return value
opts = Options() opts = Options()
if os.path.exists(config_filename): if os.path.exists(config_filename):
opts.load(config_filename) opts.load(config_filename)
class Shared(sys.modules[__name__].__class__):
"""
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
at program startup.
"""
sd_model_val = None
@property
def sd_model(self):
import modules.sd_models
return modules.sd_models.model_data.get_sd_model()
@sd_model.setter
def sd_model(self, value):
import modules.sd_models
modules.sd_models.model_data.set_sd_model(value)
sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
sys.modules[__name__].__class__ = Shared
settings_components = None settings_components = None
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings""" """assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
latent_upscale_default_mode = "Latent" latent_upscale_default_mode = "Latent"
latent_upscale_modes = { latent_upscale_modes = {
@ -620,8 +655,6 @@ latent_upscale_modes = {
sd_upscalers = [] sd_upscalers = []
sd_model = None
clip_model = None clip_model = None
progress_print_out = sys.stdout progress_print_out = sys.stdout
@ -639,8 +672,8 @@ def reload_gradio_theme(theme_name=None):
else: else:
try: try:
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name) gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
except requests.exceptions.ConnectionError: except Exception as e:
print("Can't access HuggingFace Hub, falling back to default Gradio theme") errors.display(e, "changing gradio theme")
gradio_theme = gr.themes.Default() gradio_theme = gr.themes.Default()
@ -701,3 +734,20 @@ def html(filename):
return file.read() return file.read()
return "" return ""
def walk_files(path, allowed_extensions=None):
if not os.path.exists(path):
return
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
for root, dirs, files in os.walk(path):
for filename in files:
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
continue
yield os.path.join(root, filename)

View File

@ -74,7 +74,7 @@ class StyleDatabase:
def save_styles(self, path: str) -> None: def save_styles(self, path: str) -> None:
# Always keep a backup file around # Always keep a backup file around
if os.path.exists(path): if os.path.exists(path):
shutil.copy(path, path + ".bak") shutil.copy(path, f"{path}.bak")
fd = os.open(path, os.O_RDWR|os.O_CREAT) fd = os.open(path, os.O_RDWR|os.O_CREAT)
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:

View File

@ -111,7 +111,7 @@ def focal_point(im, settings):
if corner_centroid is not None: if corner_centroid is not None:
color = BLUE color = BLUE
box = corner_centroid.bounding(max_size * corner_centroid.weight) box = corner_centroid.bounding(max_size * corner_centroid.weight)
d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color) d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color) d.ellipse(box, outline=color)
if len(corner_points) > 1: if len(corner_points) > 1:
for f in corner_points: for f in corner_points:
@ -119,7 +119,7 @@ def focal_point(im, settings):
if entropy_centroid is not None: if entropy_centroid is not None:
color = "#ff0" color = "#ff0"
box = entropy_centroid.bounding(max_size * entropy_centroid.weight) box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color) d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color) d.ellipse(box, outline=color)
if len(entropy_points) > 1: if len(entropy_points) > 1:
for f in entropy_points: for f in entropy_points:
@ -127,7 +127,7 @@ def focal_point(im, settings):
if face_centroid is not None: if face_centroid is not None:
color = RED color = RED
box = face_centroid.bounding(max_size * face_centroid.weight) box = face_centroid.bounding(max_size * face_centroid.weight)
d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color) d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color) d.ellipse(box, outline=color)
if len(face_points) > 1: if len(face_points) > 1:
for f in face_points: for f in face_points:

View File

@ -72,7 +72,7 @@ class PersonalizedBase(Dataset):
except Exception: except Exception:
continue continue
text_filename = os.path.splitext(path)[0] + ".txt" text_filename = f"{os.path.splitext(path)[0]}.txt"
filename = os.path.basename(path) filename = os.path.basename(path)
if os.path.exists(text_filename): if os.path.exists(text_filename):

View File

@ -63,9 +63,9 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
image.save(os.path.join(params.dstdir, f"{basename}.png")) image.save(os.path.join(params.dstdir, f"{basename}.png"))
if params.preprocess_txt_action == 'prepend' and existing_caption: if params.preprocess_txt_action == 'prepend' and existing_caption:
caption = existing_caption + ' ' + caption caption = f"{existing_caption} {caption}"
elif params.preprocess_txt_action == 'append' and existing_caption: elif params.preprocess_txt_action == 'append' and existing_caption:
caption = caption + ' ' + existing_caption caption = f"{caption} {existing_caption}"
elif params.preprocess_txt_action == 'copy' and existing_caption: elif params.preprocess_txt_action == 'copy' and existing_caption:
caption = existing_caption caption = existing_caption
@ -174,7 +174,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
params.src = filename params.src = filename
existing_caption = None existing_caption = None
existing_caption_filename = os.path.splitext(filename)[0] + '.txt' existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
if os.path.exists(existing_caption_filename): if os.path.exists(existing_caption_filename):
with open(existing_caption_filename, 'r', encoding="utf8") as file: with open(existing_caption_filename, 'r', encoding="utf8") as file:
existing_caption = file.read() existing_caption = file.read()

View File

@ -69,7 +69,7 @@ class Embedding:
'hash': self.checksum(), 'hash': self.checksum(),
'optimizer_state_dict': self.optimizer_state_dict, 'optimizer_state_dict': self.optimizer_state_dict,
} }
torch.save(optimizer_saved_dict, filename + '.optim') torch.save(optimizer_saved_dict, f"{filename}.optim")
def checksum(self): def checksum(self):
if self.cached_checksum is not None: if self.cached_checksum is not None:
@ -437,8 +437,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
if shared.opts.save_optimizer_state: if shared.opts.save_optimizer_state:
optimizer_state_dict = None optimizer_state_dict = None
if os.path.exists(filename + '.optim'): if os.path.exists(f"{filename}.optim"):
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu') optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None): if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
@ -599,7 +599,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
data = torch.load(last_saved_file) data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data)) info.add_text("sd-ti-embedding", embedding_to_b64(data))
title = "<{}>".format(data.get('name', '???')) title = f"<{data.get('name', '???')}>"
try: try:
vectorSize = list(data['string_to_param'].values())[0].shape[0] vectorSize = list(data['string_to_param'].values())[0].shape[0]
@ -608,8 +608,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.shorthash) footer_mid = f'[{checkpoint.shorthash}]'
footer_right = '{}v {}s'.format(vectorSize, steps_done) footer_right = f'{vectorSize}v {steps_done}s'
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data) captioned_image = insert_image_data_embed(captioned_image, data)

View File

@ -101,7 +101,7 @@ def visit(x, func, path=""):
for c in x.children: for c in x.children:
visit(c, func, path) visit(c, func, path)
elif x.label is not None: elif x.label is not None:
func(path + "/" + str(x.label), x) func(f"{path}/{x.label}", x)
def add_style(name: str, prompt: str, negative_prompt: str): def add_style(name: str, prompt: str, negative_prompt: str):
@ -166,7 +166,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
img = Image.open(image) img = Image.open(image)
filename = os.path.basename(image) filename = os.path.basename(image)
left, _ = os.path.splitext(filename) left, _ = os.path.splitext(filename)
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a')) print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
return [gr.update(), None] return [gr.update(), None]
@ -182,29 +182,29 @@ def interrogate_deepbooru(image):
def create_seed_inputs(target_interface): def create_seed_inputs(target_interface):
with FormRow(elem_id=target_interface + '_seed_row', variant="compact"): with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
seed.style(container=False) seed.style(container=False)
random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed', label='Random seed') random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed', label='Reuse seed') reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) seed_checkbox = gr.Checkbox(label='Extra', elem_id=f"{target_interface}_subseed_show", value=False)
# Components to show/hide based on the 'Extra' checkbox # Components to show/hide based on the 'Extra' checkbox
seed_extras = [] seed_extras = []
with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
seed_extras.append(seed_extra_row_1) seed_extras.append(seed_extra_row_1)
subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
subseed.style(container=False) subseed.style(container=False)
random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed') random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed') reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
with FormRow(visible=False) as seed_extra_row_2: with FormRow(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2) seed_extras.append(seed_extra_row_2)
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@ -765,7 +765,7 @@ def create_ui():
) )
button.click( button.click(
fn=lambda: None, fn=lambda: None,
_js="switch_to_"+name.replace(" ", "_"), _js=f"switch_to_{name.replace(' ', '_')}",
inputs=[], inputs=[],
outputs=[], outputs=[],
) )
@ -828,7 +828,7 @@ def create_ui():
with FormGroup(): with FormGroup():
with FormRow(): with FormRow():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
elif category == "seed": elif category == "seed":
@ -1462,18 +1462,18 @@ def create_ui():
elif t == bool: elif t == bool:
comp = gr.Checkbox comp = gr.Checkbox
else: else:
raise Exception(f'bad options item type: {str(t)} for key {key}') raise Exception(f'bad options item type: {t} for key {key}')
elem_id = "setting_"+key elem_id = f"setting_{key}"
if info.refresh is not None: if info.refresh is not None:
if is_quicksettings: if is_quicksettings:
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
else: else:
with FormRow(): with FormRow():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
else: else:
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
@ -1525,7 +1525,7 @@ def create_ui():
result = gr.HTML(elem_id="settings_result") result = gr.HTML(elem_id="settings_result")
quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] quicksettings_names = opts.quicksettings_list
quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
quicksettings_list = [] quicksettings_list = []
@ -1545,7 +1545,7 @@ def create_ui():
current_tab.__exit__() current_tab.__exit__()
gr.Group() gr.Group()
current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
current_tab.__enter__() current_tab.__enter__()
current_row = gr.Column(variant='compact') current_row = gr.Column(variant='compact')
current_row.__enter__() current_row.__enter__()
@ -1566,7 +1566,7 @@ def create_ui():
current_row.__exit__() current_row.__exit__()
current_tab.__exit__() current_tab.__exit__()
with gr.TabItem("Actions", id="actions"): with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization") download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
@ -1574,7 +1574,7 @@ def create_ui():
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
with gr.TabItem("Licenses", id="licenses"): with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages") gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
@ -1664,7 +1664,7 @@ def create_ui():
for interface, label, ifid in interfaces: for interface, label, ifid in interfaces:
if label in shared.opts.hidden_tabs: if label in shared.opts.hidden_tabs:
continue continue
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
interface.render() interface.render()
if os.path.exists(os.path.join(script_path, "notification.mp3")): if os.path.exists(os.path.join(script_path, "notification.mp3")):
@ -1693,11 +1693,9 @@ def create_ui():
show_progress=info.refresh is not None, show_progress=info.refresh is not None,
) )
text_settings.change( update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"), text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
inputs=[], demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
outputs=[image_cfg_scale],
)
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click( button_set_checkpoint.click(
@ -1773,10 +1771,10 @@ def create_ui():
def loadsave(path, x): def loadsave(path, x):
def apply_field(obj, field, condition=None, init_field=None): def apply_field(obj, field, condition=None, init_field=None):
key = path + "/" + field key = f"{path}/{field}"
if getattr(obj, 'custom_script_source', None) is not None: if getattr(obj, 'custom_script_source', None) is not None:
key = 'customscript/' + obj.custom_script_source + '/' + key key = f"customscript/{obj.custom_script_source}/{key}"
if getattr(obj, 'do_not_save_to_config', False): if getattr(obj, 'do_not_save_to_config', False):
return return
@ -1925,7 +1923,7 @@ def versions_html():
python_version = ".".join([str(x) for x in sys.version_info[0:3]]) python_version = ".".join([str(x) for x in sys.version_info[0:3]])
commit = launch.commit_hash() commit = launch.commit_hash()
short_commit = commit[0:8] tag = launch.git_tag()
if shared.xformers_available: if shared.xformers_available:
import xformers import xformers
@ -1934,6 +1932,8 @@ def versions_html():
xformers_version = "N/A" xformers_version = "N/A"
return f""" return f"""
version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
  
python: <span title="{sys.version}">{python_version}</span> python: <span title="{sys.version}">{python_version}</span>
     
torch: {getattr(torch, '__long_version__',torch.__version__)} torch: {getattr(torch, '__long_version__',torch.__version__)}
@ -1942,7 +1942,21 @@ xformers: {xformers_version}
     
gradio: {gr.__version__} gradio: {gr.__version__}
     
commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
  
checkpoint: <a id="sd_checkpoint_hash">N/A</a> checkpoint: <a id="sd_checkpoint_hash">N/A</a>
""" """
def setup_ui_api(app):
from pydantic import BaseModel, Field
from typing import List
class QuicksettingsHint(BaseModel):
name: str = Field(title="Name of the quicksettings field")
label: str = Field(title="Label of the quicksettings field")
def quicksettings_hint():
return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])

View File

@ -61,7 +61,8 @@ def save_config_state(name):
if not name: if not name:
name = "Config" name = "Config"
current_config_state["name"] = name current_config_state["name"] = name
filename = os.path.join(config_states_dir, datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + "_" + name + ".json") timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S')
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
print(f"Saving backup of webui/extension state to {filename}.") print(f"Saving backup of webui/extension state to {filename}.")
with open(filename, "w", encoding="utf-8") as f: with open(filename, "w", encoding="utf-8") as f:
json.dump(current_config_state, f) json.dump(current_config_state, f)

View File

@ -69,7 +69,9 @@ class ExtraNetworksPage:
pass pass
def link_preview(self, filename): def link_preview(self, filename):
return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename)) quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
mtime = os.path.getmtime(filename)
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
def search_terms_from_path(self, filename, possible_directories=None): def search_terms_from_path(self, filename, possible_directories=None):
abspath = os.path.abspath(filename) abspath = os.path.abspath(filename)
@ -89,19 +91,22 @@ class ExtraNetworksPage:
subdirs = {} subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True): for root, dirs, files in os.walk(parentdir):
if not os.path.isdir(x): for dirname in dirs:
continue x = os.path.join(root, dirname)
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") if not os.path.isdir(x):
while subdir.startswith("/"): continue
subdir = subdir[1:]
is_empty = len(os.listdir(x)) == 0 subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
if not is_empty and not subdir.endswith("/"): while subdir.startswith("/"):
subdir = subdir + "/" subdir = subdir[1:]
subdirs[subdir] = 1 is_empty = len(os.listdir(x)) == 0
if not is_empty and not subdir.endswith("/"):
subdir = subdir + "/"
subdirs[subdir] = 1
if subdirs: if subdirs:
subdirs = {"": 1, **subdirs} subdirs = {"": 1, **subdirs}
@ -157,8 +162,20 @@ class ExtraNetworksPage:
if metadata: if metadata:
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>" metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
local_path = ""
filename = item.get("filename", "")
for reldir in self.allowed_directories_for_previews():
absdir = os.path.abspath(reldir)
if filename.startswith(absdir):
local_path = filename[len(absdir):]
# if this is true, the item must not be show in the default view, and must instead only be
# shown when searching for it
serach_only = "/." in local_path or "\\." in local_path
args = { args = {
"style": f"'{height}{width}{background_image}'", "style": f"'display: none; {height}{width}{background_image}'",
"prompt": item.get("prompt", None), "prompt": item.get("prompt", None),
"tabname": json.dumps(tabname), "tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]), "local_preview": json.dumps(item["local_preview"]),
@ -168,6 +185,7 @@ class ExtraNetworksPage:
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""), "search_term": item.get("search_term", ""),
"metadata_button": metadata_button, "metadata_button": metadata_button,
"serach_only": " search_only" if serach_only else "",
} }
return self.card_page.format(**args) return self.card_page.format(**args)
@ -209,6 +227,11 @@ def intialize():
class ExtraNetworksUi: class ExtraNetworksUi:
def __init__(self): def __init__(self):
self.pages = None self.pages = None
"""gradio HTML components related to extra networks' pages"""
self.page_contents = None
"""HTML content of the above; empty initially, filled when extra pages have to be shown"""
self.stored_extra_pages = None self.stored_extra_pages = None
self.button_save_preview = None self.button_save_preview = None
@ -236,17 +259,22 @@ def pages_in_preferred_order(pages):
def create_ui(container, button, tabname): def create_ui(container, button, tabname):
ui = ExtraNetworksUi() ui = ExtraNetworksUi()
ui.pages = [] ui.pages = []
ui.pages_contents = []
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname ui.tabname = tabname
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
for page in ui.stored_extra_pages: for page in ui.stored_extra_pages:
with gr.Tab(page.title, id=page.title.lower().replace(" ", "_")): page_id = page.title.lower().replace(" ", "_")
page_elem = gr.HTML(page.create_html(ui.tabname)) with gr.Tab(page.title, id=page_id):
elem_id = f"{tabname}_{page_id}_cards_html"
page_elem = gr.HTML('', elem_id=elem_id)
ui.pages.append(page_elem) ui.pages.append(page_elem)
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
@ -254,19 +282,22 @@ def create_ui(container, button, tabname):
def toggle_visibility(is_visible): def toggle_visibility(is_visible):
is_visible = not is_visible is_visible = not is_visible
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
if is_visible and not ui.pages_contents:
refresh()
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")), *ui.pages_contents
state_visible = gr.State(value=False) state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button]) button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button, *ui.pages])
def refresh(): def refresh():
res = []
for pg in ui.stored_extra_pages: for pg in ui.stored_extra_pages:
pg.refresh() pg.refresh()
res.append(pg.create_html(ui.tabname))
return res ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
return ui.pages_contents
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)

View File

@ -36,7 +36,7 @@ def save_pil_to_file(pil_image, dir=None):
if already_saved_as and os.path.isfile(already_saved_as): if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as) register_tmp_file(shared.demo, already_saved_as)
file_obj = Savedfile(already_saved_as) file_obj = Savedfile(f'{already_saved_as}?{os.path.getmtime(already_saved_as)}')
return file_obj return file_obj
if shared.opts.temp_dir != "": if shared.opts.temp_dir != "":

View File

@ -5,7 +5,7 @@ basicsr
fonts fonts
font-roboto font-roboto
gfpgan gfpgan
gradio==3.28.1 gradio==3.29.0
numpy numpy
omegaconf omegaconf
opencv-contrib-python opencv-contrib-python

View File

@ -3,7 +3,7 @@ transformers==4.25.1
accelerate==0.18.0 accelerate==0.18.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.28.1 gradio==3.29.0
numpy==1.23.5 numpy==1.23.5
Pillow==9.4.0 Pillow==9.4.0
realesrgan==0.3.0 realesrgan==0.3.0

View File

@ -77,7 +77,7 @@ return process_images(p)
module.display = display module.display = display
indent = " " * indent_level indent = " " * indent_level
indented = code.replace('\n', '\n' + indent) indented = code.replace('\n', f"\n{indent}")
body = f"""def __webuitemp__(): body = f"""def __webuitemp__():
{indent}{indented} {indent}{indented}
__webuitemp__()""" __webuitemp__()"""

View File

@ -84,7 +84,7 @@ class Script(scripts.Script):
p.color_corrections = initial_color_corrections p.color_corrections = initial_color_corrections
if append_interrogation != "None": if append_interrogation != "None":
p.prompt = original_prompt + ", " if original_prompt != "" else "" p.prompt = f"{original_prompt}, " if original_prompt else ""
if append_interrogation == "CLIP": if append_interrogation == "CLIP":
p.prompt += shared.interrogator.interrogate(p.init_images[0]) p.prompt += shared.interrogator.interrogate(p.init_images[0])
elif append_interrogation == "DeepBooru": elif append_interrogation == "DeepBooru":

View File

@ -100,11 +100,10 @@ def cmdargs(line):
def load_prompt_file(file): def load_prompt_file(file):
if file is None: if file is None:
lines = [] return None, gr.update(), gr.update(lines=7)
else: else:
lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")] lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")]
return None, "\n".join(lines), gr.update(lines=7)
return None, "\n".join(lines), gr.update(lines=7)
class Script(scripts.Script): class Script(scripts.Script):
@ -118,12 +117,12 @@ class Script(scripts.Script):
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file")) file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt], show_progress=False)
# We start at one line. When the text changes, we jump to seven lines, or two lines if no \n. # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n.
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed. # be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt]) prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):

View File

@ -222,7 +222,7 @@ axis_options = [
AxisOption("Denoising", float, apply_field("denoising_strength")), AxisOption("Denoising", float, apply_field("denoising_strength")),
AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)),
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5), AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
AxisOption("Face restore", str, apply_face_restore, format_value=format_value), AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
@ -439,7 +439,7 @@ class Script(scripts.Script):
z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown]) z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])
def get_dropdown_update_from_params(axis,params): def get_dropdown_update_from_params(axis,params):
val_key = axis + " Values" val_key = f"{axis} Values"
vals = params.get(val_key,"") vals = params.get(val_key,"")
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
return gr.update(value = valslist) return gr.update(value = valslist)

View File

@ -125,6 +125,10 @@ div.gradio-html.min{
text-decoration: none; text-decoration: none;
} }
a{
font-weight: bold;
cursor: pointer;
}
/* general styled components */ /* general styled components */
@ -246,7 +250,7 @@ button.custom-button{
} }
} }
#txt2img_gallery img, #img2img_gallery img{ #txt2img_gallery img, #img2img_gallery img, #extras_gallery img{
object-fit: scale-down; object-fit: scale-down;
} }
#txt2img_actions_column, #img2img_actions_column { #txt2img_actions_column, #img2img_actions_column {
@ -397,6 +401,18 @@ div#extras_scale_to_tab div.form{
margin: 0 1.2em; margin: 0 1.2em;
} }
table.settings-value-table{
background: white;
border-collapse: collapse;
margin: 1em;
border: 4px solid white;
}
table.settings-value-table td{
padding: 0.4em;
border: 1px solid #ccc;
max-width: 36em;
}
/* live preview */ /* live preview */
.progressDiv{ .progressDiv{
@ -534,6 +550,8 @@ div#extras_scale_to_tab div.form{
#lightboxModal > img.modalImageFullscreen{ #lightboxModal > img.modalImageFullscreen{
object-fit: contain; object-fit: contain;
height: 100%; height: 100%;
width: 100%;
min-height: 0;
} }
.modalPrev, .modalPrev,

View File

@ -6,6 +6,8 @@ import signal
import re import re
import warnings import warnings
import json import json
from threading import Thread
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
@ -185,24 +187,19 @@ def initialize():
modules.scripts.load_scripts() modules.scripts.load_scripts()
startup_timer.record("load scripts") startup_timer.record("load scripts")
modelloader.load_upscalers()
#startup_timer.record("load upscalers") #Is this necessary? I don't know.
modules.sd_vae.refresh_vae_list() modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE") startup_timer.record("refresh VAE")
modules.textual_inversion.textual_inversion.list_textual_inversion_templates() modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates") startup_timer.record("refresh textual inversion templates")
try: # load model in parallel to other startup stuff
modules.sd_models.load_model() Thread(target=lambda: shared.sd_model).start()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
startup_timer.record("load SD checkpoint")
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
@ -286,7 +283,6 @@ def api_only():
print(f"Startup time: {startup_timer.summary()}.") print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
def webui(): def webui():
launch_api = cmd_opts.api launch_api = cmd_opts.api
initialize() initialize()
@ -313,6 +309,16 @@ def webui():
for line in file.readlines(): for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()] gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
# this restores the missing /docs endpoint
if launch_api and not hasattr(FastAPI, 'original_setup'):
def fastapi_setup(self):
self.docs_url = "/docs"
self.redoc_url = "/redoc"
self.original_setup()
FastAPI.original_setup = FastAPI.setup
FastAPI.setup = fastapi_setup
app, local_url, share_url = shared.demo.launch( app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share, share=cmd_opts.share,
server_name=server_name, server_name=server_name,
@ -339,6 +345,7 @@ def webui():
setup_middleware(app) setup_middleware(app)
modules.progress.setup_progress_api(app) modules.progress.setup_progress_api(app)
modules.ui.setup_ui_api(app)
if launch_api: if launch_api:
create_api(app) create_api(app)
@ -350,6 +357,11 @@ def webui():
print(f"Startup time: {startup_timer.summary()}.") print(f"Startup time: {startup_timer.summary()}.")
if cmd_opts.subpath:
redirector = FastAPI()
redirector.get("/")
mounted_app = gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
wait_on_server(shared.demo) wait_on_server(shared.demo)
print('Restarting UI...') print('Restarting UI...')

View File

@ -153,24 +153,31 @@ else
cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
fi fi
printf "\n%s\n" "${delimiter}" if [[ -z "${VIRTUAL_ENV}" ]];
printf "Create and activate python venv"
printf "\n%s\n" "${delimiter}"
cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
if [[ ! -d "${venv_dir}" ]]
then then
"${python_cmd}" -m venv "${venv_dir}" printf "\n%s\n" "${delimiter}"
first_launch=1 printf "Create and activate python venv"
fi printf "\n%s\n" "${delimiter}"
# shellcheck source=/dev/null cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
if [[ -f "${venv_dir}"/bin/activate ]] if [[ ! -d "${venv_dir}" ]]
then then
source "${venv_dir}"/bin/activate "${python_cmd}" -m venv "${venv_dir}"
first_launch=1
fi
# shellcheck source=/dev/null
if [[ -f "${venv_dir}"/bin/activate ]]
then
source "${venv_dir}"/bin/activate
else
printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
printf "\n%s\n" "${delimiter}"
exit 1
fi
else else
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m" printf "python venv already activate: ${VIRTUAL_ENV}"
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
exit 1
fi fi
# Try using TCMalloc on Linux # Try using TCMalloc on Linux