Merge branch 'dev' into release_candidate

This commit is contained in:
AUTOMATIC1111 2024-07-20 11:51:12 +03:00
commit 5bbbda473f
27 changed files with 324 additions and 107 deletions

2
.gitignore vendored
View File

@ -2,6 +2,7 @@ __pycache__
*.ckpt
*.safetensors
*.pth
.DS_Store
/ESRGAN/*
/SwinIR/*
/repositories
@ -40,3 +41,4 @@ notification.mp3
/test/test_outputs
/cache
trace.json
/sysinfo-????-??-??-??-??.json

View File

@ -128,10 +128,32 @@ sudo zypper install wget git python3 libtcmalloc4 libglvnd
# Arch-based:
sudo pacman -S wget git python3
```
If your system is very new, you need to install python3.11 or python3.10:
```bash
# Ubuntu 24.04
sudo add-apt-repository ppa:deadsnakes/ppa
sudo apt update
sudo apt install python3.11
# Manjaro/Arch
sudo pacman -S yay
yay -S python311 # do not confuse with python3.11 package
# Only for 3.11
# Then set up env variable in launch script
export python_cmd="python3.11"
# or in webui-user.sh
python_cmd="python3.11"
```
2. Navigate to the directory you would like the webui to be installed and execute the following command:
```bash
wget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh
```
Or just clone the repo wherever you want:
```bash
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui
```
3. Run `webui.sh`.
4. Check `webui-user.sh` for options.
### Installation on Apple Silicon

View File

@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from modules import sd_models, cache, errors, hashes, shared
import modules.models.sd3.mmdit
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@ -114,7 +115,10 @@ class NetworkModule:
self.sd_key = weights.sd_key
self.sd_module = weights.sd_module
if hasattr(self.sd_module, 'weight'):
if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
s = self.sd_module.weight.shape
self.shape = (s[0] // 3, s[1])
elif hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
elif isinstance(self.sd_module, nn.MultiheadAttention):
# For now, only self-attn use Pytorch's MHA

View File

@ -1,6 +1,7 @@
import torch
import lyco_helpers
import modules.models.sd3.mmdit
import network
from modules import devices
@ -10,6 +11,13 @@ class ModuleTypeLora(network.ModuleType):
if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
return NetworkModuleLora(net, weights)
if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]):
w = weights.w.copy()
weights.w.clear()
weights.w.update({"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]})
return NetworkModuleLora(net, weights)
return None
@ -29,7 +37,7 @@ class NetworkModuleLora(network.NetworkModule):
if weight is None and none_ok:
return None
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
if is_linear:

View File

@ -20,6 +20,7 @@ from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
import modules.textual_inversion.textual_inversion as textual_inversion
import modules.models.sd3.mmdit
from lora_logger import logger
@ -166,12 +167,26 @@ def load_network(name, network_on_disk):
keys_failed_to_match = {}
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
if hasattr(shared.sd_model, 'diffusers_weight_map'):
diffusers_weight_map = shared.sd_model.diffusers_weight_map
elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):
diffusers_weight_map = {}
for k, v in shared.sd_model.diffusers_weight_mapping():
diffusers_weight_map[k] = v
shared.sd_model.diffusers_weight_map = diffusers_weight_map
else:
diffusers_weight_map = None
matched_networks = {}
bundle_embeddings = {}
for key_network, weight in sd.items():
key_network_without_network_parts, _, network_part = key_network.partition(".")
if diffusers_weight_map:
key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
network_part = network_name + '.' + network_weight
else:
key_network_without_network_parts, _, network_part = key_network.partition(".")
if key_network_without_network_parts == "bundle_emb":
emb_name, vec_name = network_part.split(".", 1)
@ -183,7 +198,11 @@ def load_network(name, network_on_disk):
emb_dict[vec_name] = weight
bundle_embeddings[emb_name] = emb_dict
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
if diffusers_weight_map:
key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)
else:
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
if sd_module is None:
@ -347,6 +366,28 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
purge_networks_from_memory()
def allowed_layer_without_weight(layer):
if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
return True
return False
def store_weights_backup(weight):
if weight is None:
return None
return weight.to(devices.cpu, copy=True)
def restore_weights_backup(obj, field, weight):
if weight is None:
setattr(obj, field, None)
return
getattr(obj, field).copy_(weight)
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
@ -356,21 +397,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
else:
self.weight.copy_(weights_backup)
restore_weights_backup(self, 'weight', weights_backup)
if bias_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.out_proj.bias.copy_(bias_backup)
else:
self.bias.copy_(bias_backup)
if isinstance(self, torch.nn.MultiheadAttention):
restore_weights_backup(self.out_proj, 'bias', bias_backup)
else:
if isinstance(self, torch.nn.MultiheadAttention):
self.out_proj.bias = None
else:
self.bias = None
restore_weights_backup(self, 'bias', bias_backup)
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
@ -389,22 +424,22 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
weights_backup = getattr(self, "network_weights_backup", None)
if weights_backup is None and wanted_names != ():
if current_names != ():
raise RuntimeError("no backup weights found and current weights are not unchanged")
if current_names != () and not allowed_layer_without_weight(self):
raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
else:
weights_backup = self.weight.to(devices.cpu, copy=True)
weights_backup = store_weights_backup(self.weight)
self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None and wanted_names != ():
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
bias_backup = store_weights_backup(self.out_proj.bias)
elif getattr(self, 'bias', None) is not None:
bias_backup = self.bias.to(devices.cpu, copy=True)
bias_backup = store_weights_backup(self.bias)
else:
bias_backup = None
@ -412,6 +447,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
# Only report if bias is not None and current bias are not unchanged.
if bias_backup is not None and current_names != ():
raise RuntimeError("no backup bias found and current bias are not unchanged")
self.network_bias_backup = bias_backup
if current_names != wanted_names:
@ -419,7 +455,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'):
if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
try:
with torch.no_grad():
if getattr(self, 'fp16_weight', None) is None:
@ -479,6 +515,24 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
continue
if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
try:
with torch.no_grad():
# Send "real" orig_weight into MHA's lora module
qw, kw, vw = self.weight.chunk(3, 0)
updown_q, _ = module_q.calc_updown(qw)
updown_k, _ = module_k.calc_updown(kw)
updown_v, _ = module_v.calc_updown(vw)
del qw, kw, vw
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
self.weight += updown_qkv
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
continue
if module is None:
continue

View File

@ -113,7 +113,7 @@ def encode_pil_to_base64(image):
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
if image.mode == "RGBA":
if image.mode in ("RGBA", "P"):
image = image.convert("RGB")
parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({

View File

@ -47,6 +47,22 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
@wraps(func)
def f(*args, **kwargs):
try:
res = func(*args, **kwargs)
finally:
shared.state.skipped = False
shared.state.interrupted = False
shared.state.stopping_generation = False
shared.state.job_count = 0
shared.state.job = ""
return res
return wrap_gradio_call_no_job(f, extra_outputs, add_stats)
def wrap_gradio_call_no_job(func, extra_outputs=None, add_stats=False):
@wraps(func)
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
@ -66,9 +82,6 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
errors.report(f"{message}\n{arg_str}", exc_info=True)
shared.state.job = ""
shared.state.job_count = 0
if extra_outputs_array is None:
extra_outputs_array = [None, '']
@ -77,11 +90,6 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
devices.torch_gc()
shared.state.skipped = False
shared.state.interrupted = False
shared.state.stopping_generation = False
shared.state.job_count = 0
if not add_stats:
return tuple(res)
@ -123,3 +131,4 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
return tuple(res)
return f

View File

@ -146,18 +146,19 @@ def connect_paste_params_buttons():
destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
if binding.source_image_component and destination_image_component:
need_send_dementions = destination_width_component and binding.tabname != 'inpaint'
if isinstance(binding.source_image_component, gr.Gallery):
func = send_image_and_dimensions if destination_width_component else image_from_url_text
func = send_image_and_dimensions if need_send_dementions else image_from_url_text
jsfunc = "extract_image_from_gallery"
else:
func = send_image_and_dimensions if destination_width_component else lambda x: x
func = send_image_and_dimensions if need_send_dementions else lambda x: x
jsfunc = None
binding.paste_button.click(
fn=func,
_js=jsfunc,
inputs=[binding.source_image_component],
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
outputs=[destination_image_component, destination_width_component, destination_height_component] if need_send_dementions else [destination_image_component],
show_progress=False,
)

View File

@ -446,7 +446,6 @@ def prepare_environment():
exit(0)
def configure_for_tests():
if "--api" not in sys.argv:
sys.argv.append("--api")

View File

@ -175,6 +175,9 @@ class VectorEmbedder(nn.Module):
#################################################################################
class QkvLinear(torch.nn.Linear):
pass
def split_qkv(qkv, head_dim):
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
return qkv[0], qkv[1], qkv[2]
@ -202,7 +205,7 @@ class SelfAttention(nn.Module):
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
if not pre_only:
self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
assert attn_mode in self.ATTENTION_MODES

View File

@ -5,6 +5,8 @@ import math
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast
from modules import sd_hijack
#################################################################################################
### Core/Utility
@ -110,9 +112,9 @@ class CLIPEncoder(torch.nn.Module):
class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
super().__init__()
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens):
@ -127,7 +129,7 @@ class CLIPTextModel_(torch.nn.Module):
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)

View File

@ -40,6 +40,7 @@ CLIPG_CONFIG = {
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
"textual_inversion_key": "clip_g",
}
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
@ -204,7 +205,10 @@ class SD3Cond(torch.nn.Module):
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
def encode_embedding_init_text(self, init_text, nvpt):
return torch.tensor([[0]], device=devices.device) # XXX
return self.model_lg.encode_embedding_init_text(init_text, nvpt)
def tokenize(self, texts):
return self.model_lg.tokenize(texts)
def medvram_modules(self):
return [self.clip_g, self.clip_l, self.t5xxl]

View File

@ -67,6 +67,7 @@ class BaseModel(torch.nn.Module):
}
self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
self.depth = depth
def apply_model(self, x, sigma, c_crossattn=None, y=None):
dtype = self.get_dtype()

View File

@ -82,3 +82,15 @@ class SD3Inferencer(torch.nn.Module):
def fix_dimensions(self, width, height):
return width // 16 * 16, height // 16 * 16
def diffusers_weight_mapping(self):
for i in range(self.model.depth):
yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj"

View File

@ -359,13 +359,28 @@ class EmbeddingsWithFixes(torch.nn.Module):
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
emb = devices.cond_cast_unet(vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
vecs.append(tensor)
return torch.stack(vecs)
class TextualInversionEmbeddings(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):
super().__init__(num_embeddings, embedding_dim, **kwargs)
self.embeddings = model_hijack
self.textual_inversion_key = textual_inversion_key
@property
def wrapped(self):
return super().forward
def forward(self, input_ids):
return EmbeddingsWithFixes.forward(self, input_ids)
def add_circular_option_to_conv_2d():
conv2d_constructor = torch.nn.Conv2d.__init__

View File

@ -120,6 +120,10 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
if scheduler.need_inner_model:
sigmas_kwargs['inner_model'] = self.model_wrap
if scheduler.label == 'Beta':
p.extra_generation_params["Beta schedule alpha"] = opts.beta_dist_alpha
p.extra_generation_params["Beta schedule beta"] = opts.beta_dist_beta
sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=devices.cpu)
if discard_next_to_last_sigma:

View File

@ -2,6 +2,7 @@ import dataclasses
import torch
import k_diffusion
import numpy as np
from scipy import stats
from modules import shared
@ -115,6 +116,17 @@ def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
return torch.FloatTensor(sigs).to(device)
def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
alpha = shared.opts.beta_dist_alpha
beta = shared.opts.beta_dist_beta
timesteps = 1 - np.linspace(0, 1, n)
timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]
sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps]
sigmas += [0.0]
return torch.FloatTensor(sigmas).to(device)
schedulers = [
Scheduler('automatic', 'Automatic', None),
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
@ -127,6 +139,7 @@ schedulers = [
Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True),
Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),
Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),
Scheduler('beta', 'Beta', beta_scheduler, need_inner_model=True),
]
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}

View File

@ -64,6 +64,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"save_write_log_csv": OptionInfo(True, "Write log.csv when saving images using 'Save' button"),
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
@ -404,6 +405,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models"),
'skip_early_cond': OptionInfo(0.0, "Ignore negative prompt during early sampling", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext="Skip Early CFG").info("disables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling"),
'beta_dist_alpha': OptionInfo(0.6, "Beta scheduler - alpha", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler alpha').info('Default = 0.6; the alpha parameter of the beta distribution used in Beta sampling'),
'beta_dist_beta': OptionInfo(0.6, "Beta scheduler - beta", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Beta scheduler beta').info('Default = 0.6; the beta parameter of the beta distribution used in Beta sampling'),
}))
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {

View File

@ -162,7 +162,7 @@ class State:
errors.record_exception()
def assign_current_image(self, image):
if shared.opts.live_previews_image_format == 'jpeg' and image.mode == 'RGBA':
if shared.opts.live_previews_image_format == 'jpeg' and image.mode in ('RGBA', 'P'):
image = image.convert('RGB')
self.current_image = image
self.id_live_preview += 1

View File

@ -1,15 +1,13 @@
import json
import os
import sys
import subprocess
import platform
import hashlib
import pkg_resources
import psutil
import re
from pathlib import Path
import launch
from modules import paths_internal, timer, shared, extensions, errors
from modules import paths_internal, timer, shared_cmd_options, errors, launch_utils
checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
environment_whitelist = {
@ -69,14 +67,46 @@ def check(x):
return h.hexdigest() == m.group(1)
def get_dict():
ram = psutil.virtual_memory()
def get_cpu_info():
cpu_info = {"model": platform.processor()}
try:
import psutil
cpu_info["count logical"] = psutil.cpu_count(logical=True)
cpu_info["count physical"] = psutil.cpu_count(logical=False)
except Exception as e:
cpu_info["error"] = str(e)
return cpu_info
def get_ram_info():
try:
import psutil
ram = psutil.virtual_memory()
return {x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0}
except Exception as e:
return str(e)
def get_packages():
try:
return subprocess.check_output([sys.executable, '-m', 'pip', 'freeze', '--all']).decode("utf8").splitlines()
except Exception as pip_error:
try:
import importlib.metadata
packages = importlib.metadata.distributions()
return sorted([f"{package.metadata['Name']}=={package.version}" for package in packages])
except Exception as e2:
return {'error pip': pip_error, 'error importlib': str(e2)}
def get_dict():
config = get_config()
res = {
"Platform": platform.platform(),
"Python": platform.python_version(),
"Version": launch.git_tag(),
"Commit": launch.commit_hash(),
"Version": launch_utils.git_tag(),
"Commit": launch_utils.commit_hash(),
"Git status": git_status(paths_internal.script_path),
"Script path": paths_internal.script_path,
"Data path": paths_internal.data_path,
"Extensions dir": paths_internal.extensions_dir,
@ -84,20 +114,14 @@ def get_dict():
"Commandline": get_argv(),
"Torch env info": get_torch_sysinfo(),
"Exceptions": errors.get_exceptions(),
"CPU": {
"model": platform.processor(),
"count logical": psutil.cpu_count(logical=True),
"count physical": psutil.cpu_count(logical=False),
},
"RAM": {
x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0
},
"Extensions": get_extensions(enabled=True),
"Inactive extensions": get_extensions(enabled=False),
"CPU": get_cpu_info(),
"RAM": get_ram_info(),
"Extensions": get_extensions(enabled=True, fallback_disabled_extensions=config.get('disabled_extensions', [])),
"Inactive extensions": get_extensions(enabled=False, fallback_disabled_extensions=config.get('disabled_extensions', [])),
"Environment": get_environment(),
"Config": get_config(),
"Config": config,
"Startup": timer.startup_record,
"Packages": sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]),
"Packages": get_packages(),
}
return res
@ -111,11 +135,11 @@ def get_argv():
res = []
for v in sys.argv:
if shared.cmd_opts.gradio_auth and shared.cmd_opts.gradio_auth == v:
if shared_cmd_options.cmd_opts.gradio_auth and shared_cmd_options.cmd_opts.gradio_auth == v:
res.append("<hidden>")
continue
if shared.cmd_opts.api_auth and shared.cmd_opts.api_auth == v:
if shared_cmd_options.cmd_opts.api_auth and shared_cmd_options.cmd_opts.api_auth == v:
res.append("<hidden>")
continue
@ -123,6 +147,7 @@ def get_argv():
return res
re_newline = re.compile(r"\r*\n")
@ -136,25 +161,55 @@ def get_torch_sysinfo():
return str(e)
def get_extensions(*, enabled):
def run_git(path, *args):
try:
def to_json(x: extensions.Extension):
return {
"name": x.name,
"path": x.path,
"version": x.version,
"branch": x.branch,
"remote": x.remote,
}
return subprocess.check_output([launch_utils.git, '-C', path, *args], shell=False, encoding='utf8').strip()
except Exception as e:
return str(e)
return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled]
def git_status(path):
if (Path(path) / '.git').is_dir():
return run_git(paths_internal.script_path, 'status')
def get_info_from_repo_path(path: Path):
is_repo = (path / '.git').is_dir()
return {
'name': path.name,
'path': str(path),
'commit': run_git(path, 'rev-parse', 'HEAD') if is_repo else None,
'branch': run_git(path, 'branch', '--show-current') if is_repo else None,
'remote': run_git(path, 'remote', 'get-url', 'origin') if is_repo else None,
}
def get_extensions(*, enabled, fallback_disabled_extensions=None):
try:
from modules import extensions
if extensions.extensions:
def to_json(x: extensions.Extension):
return {
"name": x.name,
"path": x.path,
"commit": x.commit_hash,
"branch": x.branch,
"remote": x.remote,
}
return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled]
else:
return [get_info_from_repo_path(d) for d in Path(paths_internal.extensions_dir).iterdir() if d.is_dir() and enabled != (str(d.name) in fallback_disabled_extensions)]
except Exception as e:
return str(e)
def get_config():
try:
from modules import shared
return shared.opts.data
except Exception as e:
return str(e)
except Exception as _:
try:
with open(shared_cmd_options.cmd_opts.ui_settings_file, 'r') as f:
return json.load(f)
except Exception as e:
return str(e)

View File

@ -10,7 +10,7 @@ import gradio as gr
import gradio.utils
import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call, wrap_gradio_call_no_job # noqa: F401
from modules import gradio_extensons, sd_schedulers # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
@ -622,8 +622,8 @@ def create_ui():
with gr.Column(elem_id="img2img_column_size", scale=4):
selected_scale_tab = gr.Number(value=0, visible=False)
with gr.Tabs():
with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
with gr.Tabs(elem_id="img2img_tabs_resize"):
with gr.Tab(label="Resize to", id="to", elem_id="img2img_tab_resize_to") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
@ -632,7 +632,7 @@ def create_ui():
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img")
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
with gr.Tab(label="Resize by", id="by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
with FormRow():
@ -889,7 +889,7 @@ def create_ui():
))
image.change(
fn=wrap_gradio_call(modules.extras.run_pnginfo),
fn=wrap_gradio_call_no_job(modules.extras.run_pnginfo),
inputs=[image],
outputs=[html, generation_info, html2],
)

View File

@ -3,6 +3,7 @@ import dataclasses
import json
import html
import os
from contextlib import nullcontext
import gradio as gr
@ -103,14 +104,15 @@ def save_files(js_data, images, do_make_zip, index):
# NOTE: ensure csv integrity when fields are added by
# updating headers and padding with delimiters where needed
if os.path.exists(logfile_path):
if shared.opts.save_write_log_csv and os.path.exists(logfile_path):
update_logfile(logfile_path, fields)
with open(logfile_path, "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0
writer = csv.writer(file)
if at_start:
writer.writerow(fields)
with (open(logfile_path, "a", encoding="utf8", newline='') if shared.opts.save_write_log_csv else nullcontext()) as file:
if file:
at_start = file.tell() == 0
writer = csv.writer(file)
if at_start:
writer.writerow(fields)
for image_index, filedata in enumerate(images, start_index):
image = image_from_url_text(filedata)
@ -130,7 +132,8 @@ def save_files(js_data, images, do_make_zip, index):
filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn)
writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt'], data["sd_model_name"], data["sd_model_hash"]])
if file:
writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt'], data["sd_model_name"], data["sd_model_hash"]])
# Make Zip
if do_make_zip:
@ -228,7 +231,7 @@ def create_output_panel(tabname, outdir, toprow=None):
)
save.click(
fn=call_queue.wrap_gradio_call(save_files),
fn=call_queue.wrap_gradio_call_no_job(save_files),
_js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
inputs=[
res.generation_info,
@ -244,7 +247,7 @@ def create_output_panel(tabname, outdir, toprow=None):
)
save_zip.click(
fn=call_queue.wrap_gradio_call(save_files),
fn=call_queue.wrap_gradio_call_no_job(save_files),
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
inputs=[
res.generation_info,

View File

@ -624,37 +624,37 @@ def create_ui():
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
fn=modules.ui.wrap_gradio_call_no_job(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
inputs=[extension_to_install, selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],
outputs=[available_extensions_table, extensions_table, install_result],
)
search_extensions_text.change(
fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]),
fn=modules.ui.wrap_gradio_call_no_job(search_extensions, extra_outputs=[gr.update()]),
inputs=[search_extensions_text, selected_tags, showing_type, filtering_type, sort_column],
outputs=[available_extensions_table, install_result],
)
selected_tags.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],
outputs=[available_extensions_table, install_result]
)
showing_type.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],
outputs=[available_extensions_table, install_result]
)
filtering_type.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],
outputs=[available_extensions_table, install_result]
)
sort_column.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],
outputs=[available_extensions_table, install_result]
)
@ -667,7 +667,7 @@ def create_ui():
install_result = gr.HTML(elem_id="extension_install_result")
install_button.click(
fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
fn=modules.ui.wrap_gradio_call_no_job(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
inputs=[install_dirname, install_url, install_branch],
outputs=[install_url, extensions_table, install_result],
)

View File

@ -1,7 +1,7 @@
import gradio as gr
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items
from modules.call_queue import wrap_gradio_call
from modules.call_queue import wrap_gradio_call_no_job
from modules.options import options_section
from modules.shared import opts
from modules.ui_components import FormRow
@ -295,7 +295,7 @@ class UiSettings:
def add_functionality(self, demo):
self.submit.click(
fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
fn=wrap_gradio_call_no_job(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
inputs=self.components,
outputs=[self.text_settings, self.result],
)

View File

@ -56,8 +56,8 @@ class Upscaler:
dest_w = int((img.width * scale) // 8 * 8)
dest_h = int((img.height * scale) // 8 * 8)
for _ in range(3):
if img.width >= dest_w and img.height >= dest_h and scale != 1:
for i in range(3):
if img.width >= dest_w and img.height >= dest_h and (i > 0 or scale != 1):
break
if shared.state.interrupted:

View File

@ -118,7 +118,7 @@ def apply_size(p, x: str, xs) -> None:
def find_vae(name: str):
if name := name.strip().lower() in ('auto', 'automatic'):
if (name := name.strip().lower()) in ('auto', 'automatic'):
return 'Automatic'
elif name == 'none':
return 'None'
@ -259,6 +259,8 @@ axis_options = [
AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
AxisOption("Schedule rho", float, apply_override("rho")),
AxisOption("Beta schedule alpha", float, apply_override("beta_dist_alpha")),
AxisOption("Beta schedule beta", float, apply_override("beta_dist_beta")),
AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_override('CLIP_stop_at_last_layers')),
AxisOption("Denoising", float, apply_field("denoising_strength")),

View File

@ -48,6 +48,7 @@ echo Warning: Failed to upgrade PIP version
:activate_venv
set PYTHON="%VENV_DIR%\Scripts\Python.exe"
call "%VENV_DIR%\Scripts\activate.bat"
echo venv %PYTHON%
:skip_venv