Merge branch 'AUTOMATIC1111:master' into img2img-api-scripts

This commit is contained in:
noodleanon 2023-01-07 14:18:09 +00:00 committed by GitHub
commit 50e2536279
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 813 additions and 265 deletions

View File

@ -1,9 +1,7 @@
# Stable Diffusion web UI
A browser interface based on Gradio library for Stable Diffusion.
![](txt2img_Screenshot.png)
Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.
![](screenshot.png)
## Features
[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
@ -97,9 +95,8 @@ Alternatively, use online services (like Google Colab):
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
2. Install [git](https://git-scm.com/download/win).
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
4. Place `model.ckpt` in the `models` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
5. _*(Optional)*_ Place `GFPGANv1.4.pth` in the base directory, alongside `webui.py` (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
6. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
4. Place stable diffusion checkpoint (`model.ckpt`) in the `models/Stable-diffusion` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
5. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
### Automatic Installation on Linux
1. Install the dependencies:
@ -141,6 +138,7 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)
- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot

View File

@ -184,7 +184,7 @@ SOFTWARE.
</pre>
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
<small>Code added by contirubtors, most likely copied from this repository.</small>
<small>Code added by contributors, most likely copied from this repository.</small>
<pre>
Apache License
@ -390,3 +390,30 @@ SOFTWARE.
limitations under the License.
</pre>
<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
<pre>
MIT License
Copyright (c) 2023 Alex Birch
Copyright (c) 2023 Amin Rezaei
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

View File

@ -125,7 +125,7 @@ class ExtrasBaseRequest(BaseModel):
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")

View File

@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs):
return orig_tensor_numpy(self, *args, **kwargs)
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
torch.Tensor.to = tensor_to_fix
torch.nn.functional.layer_norm = layer_norm_fix
torch.Tensor.numpy = numpy_fix
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
orig_cumsum = torch.cumsum
orig_Tensor_cumsum = torch.Tensor.cumsum
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':
output_dtype = kwargs.get('dtype', input.dtype)
if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
return cumsum_func(input, *args, **kwargs)
if has_mps():
if version.parse(torch.__version__) < version.parse("1.13"):
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
torch.Tensor.to = tensor_to_fix
torch.nn.functional.layer_norm = layer_norm_fix
torch.Tensor.numpy = numpy_fix
elif version.parse(torch.__version__) > version.parse("1.13.1"):
if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )

View File

@ -13,7 +13,7 @@ import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
@ -457,7 +457,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
if shared.opts.save_training_settings_to_txt:
saved_params = dict(
model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
)
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

View File

@ -711,7 +711,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = 0
self.truncate_y = 0
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
if self.hr_resize_x == 0 and self.hr_resize_y == 0:

View File

@ -71,6 +71,7 @@ callback_map = dict(
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
callbacks_script_unloaded=[],
)
@ -171,6 +172,14 @@ def image_grid_callback(params: ImageGridLoopParams):
report_exception(c, 'image_grid')
def script_unloaded_callback():
for c in reversed(callback_map['callbacks_script_unloaded']):
try:
c.callback()
except Exception:
report_exception(c, 'script_unloaded')
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@ -202,7 +211,7 @@ def on_app_started(callback):
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
passed as an argument; this function is also called when the script is reloaded. """
add_callback(callback_map['callbacks_model_loaded'], callback)
@ -279,3 +288,10 @@ def on_image_grid(callback):
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
"""
add_callback(callback_map['callbacks_image_grid'], callback)
def on_script_unloaded(callback):
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
the script did should be reverted here"""
add_callback(callback_map['callbacks_script_unloaded'], callback)

View File

@ -290,7 +290,6 @@ class ScriptRunner:
script.group = group
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
for script in self.selectable_scripts:

View File

@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
@ -43,20 +41,19 @@ def apply_optimizations():
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
optimization_method = 'xformers'
elif cmd_opts.opt_sub_quad_attention:
print("Applying sub-quadratic cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
optimization_method = 'sub-quadratic'
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
optimization_method = 'V1'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
if not invokeAI_mps_available and shared.device.type == 'mps':
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
optimization_method = 'V1'
else:
print("Applying cross attention optimization (InvokeAI).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
optimization_method = 'InvokeAI'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
print("Applying cross attention optimization (InvokeAI).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
optimization_method = 'InvokeAI'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
@ -150,10 +147,10 @@ class StableDiffusionModelHijack:
def clear_comments(self):
self.comments = []
def tokenize(self, text):
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
def get_prompt_lengths(self, text):
_, token_count = self.clip.process_texts([text])
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
return token_count, self.clip.get_target_prompt_token_count(token_count)
class EmbeddingsWithFixes(torch.nn.Module):

View File

@ -1,30 +1,89 @@
import math
from collections import namedtuple
import torch
from modules import prompt_parser, devices
from modules import prompt_parser, devices, sd_hijack
from modules.shared import opts
def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75
class PromptChunk:
"""
This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
so just 75 tokens from prompt.
"""
def __init__(self):
self.tokens = []
self.multipliers = []
self.fixes = []
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
"""A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
have unlimited prompt length and assign weights to tokens in prompt.
"""
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.hijack = hijack
"""Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
depending on model."""
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
self.chunk_length = 75
def empty_chunk(self):
"""creates an empty PromptChunk and returns it"""
chunk = PromptChunk()
chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
chunk.multipliers = [1.0] * (self.chunk_length + 2)
return chunk
def get_target_prompt_token_count(self, token_count):
"""returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
def tokenize(self, texts):
"""Converts a batch of texts into a batch of token ids"""
raise NotImplementedError
def encode_with_transformers(self, tokens):
"""
converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
All python lists with tokens are assumed to have same length, usually 77.
if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
model - can be 768 and 1024.
Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
"""
raise NotImplementedError
def encode_embedding_init_text(self, init_text, nvpt):
"""Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
raise NotImplementedError
def tokenize_line(self, line, used_custom_terms, hijack_comments):
def tokenize_line(self, line):
"""
this transforms a single prompt into a list of PromptChunk objects - as many as needed to
represent the prompt.
Returns the list and the total number of tokens in the prompt.
"""
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
@ -32,205 +91,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
tokenized = self.tokenize([text for text, _ in parsed])
fixes = []
remade_tokens = []
multipliers = []
chunks = []
chunk = PromptChunk()
token_count = 0
last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
while i < len(tokens):
token = tokens[i]
def next_chunk():
"""puts current chunk into the list of results and produces the next one - empty"""
nonlocal token_count
nonlocal last_comma
nonlocal chunk
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
token_count += len(chunk.tokens)
to_add = self.chunk_length - len(chunk.tokens)
if to_add > 0:
chunk.tokens += [self.id_end] * to_add
chunk.multipliers += [1.0] * to_add
chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
last_comma = -1
chunks.append(chunk)
chunk = PromptChunk()
for tokens, (text, weight) in zip(tokenized, parsed):
position = 0
while position < len(tokens):
token = tokens[position]
if token == self.comma_token:
last_comma = len(remade_tokens)
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]
last_comma = len(chunk.tokens)
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
# this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
# is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
break_location = last_comma + 1
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [self.id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
reloc_tokens = chunk.tokens[break_location:]
reloc_mults = chunk.multipliers[break_location:]
chunk.tokens = chunk.tokens[:break_location]
chunk.multipliers = chunk.multipliers[:break_location]
next_chunk()
chunk.tokens = reloc_tokens
chunk.multipliers = reloc_mults
if len(chunk.tokens) == self.chunk_length:
next_chunk()
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
iteration = len(remade_tokens) // 75
if (len(remade_tokens) + emb_len) // 75 != iteration:
rem = (75 * (iteration + 1) - len(remade_tokens))
remade_tokens += [self.id_end] * rem
multipliers += [1.0] * rem
iteration += 1
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
chunk.tokens.append(token)
chunk.multipliers.append(weight)
position += 1
continue
token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count)
tokens_to_add = prompt_target_length - len(remade_tokens)
emb_len = int(embedding.vec.shape[0])
if len(chunk.tokens) + emb_len > self.chunk_length:
next_chunk()
remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
multipliers = multipliers + [1.0] * tokens_to_add
chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
return remade_tokens, fixes, multipliers, token_count
chunk.tokens += [0] * emb_len
chunk.multipliers += [weight] * emb_len
position += embedding_length_in_tokens
if len(chunk.tokens) > 0 or len(chunks) == 0:
next_chunk()
return chunks, token_count
def process_texts(self, texts):
"""
Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
length, in tokens, of all texts.
"""
def process_text(self, texts):
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_multipliers = []
batch_chunks = []
for line in texts:
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
chunks = cache[line]
else:
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
chunks, current_token_count = self.tokenize_line(line)
token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
cache[line] = chunks
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
batch_chunks.append(chunks)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
return batch_chunks, token_count
def process_text_old(self, texts):
id_start = self.id_start
id_end = self.id_end
maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
def forward(self, texts):
"""
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
An example shape returned by this function can be: (2, 77, 768).
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
"""
cache = {}
batch_tokens = self.tokenize(texts)
batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if opts.use_old_emphasis_implementation:
import modules.sd_hijack_clip_old
return modules.sd_hijack_clip_old.forward_old(self, texts)
if tuple_tokens in cache:
remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
multipliers = []
mult = 1.0
batch_chunks, token_count = self.process_texts(texts)
i = 0
while i < len(tokens):
token = tokens[i]
used_embeddings = {}
chunk_count = max([len(x) for x in batch_chunks])
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
zs = []
for i in range(chunk_count):
batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
i += 1
elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
tokens = [x.tokens for x in batch_chunk]
multipliers = [x.multipliers for x in batch_chunk]
self.hijack.fixes = [x.fixes for x in batch_chunk]
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
for fixes in self.hijack.fixes:
for position, embedding in fixes:
used_embeddings[embedding.name] = embedding
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
z = self.process_tokens(tokens, multipliers)
zs.append(z)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
if len(used_embeddings) > 0:
embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]
self.hijack.fixes = []
for unfiltered in hijack_fixes:
fixes = []
for fix in unfiltered:
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
if len(remade_batch_tokens[j]) > 0:
tokens.append(remade_batch_tokens[j][:75])
multipliers.append(batch_multipliers[j][:75])
else:
tokens.append([self.id_end] * 75)
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
return torch.hstack(zs)
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
"""
sends one single prompt chunk to be encoded by transformers neural network.
remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
corresponds to one token.
"""
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
# this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
if self.id_end != self.id_pad:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(self.id_end)
@ -239,8 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
z = self.encode_with_transformers(tokens)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()

View File

@ -0,0 +1,81 @@
from modules import sd_hijack_clip
from modules import shared
def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
id_start = self.id_start
id_end = self.id_end
maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_tokens = self.tokenize(texts)
batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
multipliers = []
mult = 1.0
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
i += 1
elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)

View File

@ -1,7 +1,7 @@
import math
import sys
import traceback
import importlib
import psutil
import torch
from torch import einsum
@ -12,6 +12,8 @@ from einops import rearrange
from modules import shared
from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
print(traceback.format_exc(), file=sys.stderr)
def get_available_vram():
if shared.device.type == 'cuda':
stats = torch.cuda.memory_stats(shared.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
return mem_free_total
else:
return psutil.virtual_memory().available
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
mem_free_total = get_available_vram()
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
def check_for_psutil():
try:
spec = importlib.util.find_spec('psutil')
return spec is not None
except ModuleNotFoundError:
return False
invokeAI_mps_available = check_for_psutil()
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available:
import psutil
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_compvis(q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
@ -215,6 +214,71 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
# -- End of code from https://github.com/invoke-ai/InvokeAI --
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None):
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
h = self.heads
q = self.to_q(x)
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, context_k, context_v, x
q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
out_proj, dropout = self.to_out
x = out_proj(x)
x = dropout(x)
return x
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
bytes_per_token = torch.finfo(q.dtype).bits//8
batch_x_heads, q_tokens, _ = q.shape
_, k_tokens, _ = k.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
if chunk_threshold is None:
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
elif chunk_threshold == 0:
chunk_threshold_bytes = None
else:
chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
elif kv_chunk_size_min == 0:
kv_chunk_size_min = None
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
query_chunk_size = q_tokens
kv_chunk_size = k_tokens
return efficient_dot_product_attention(
q,
k,
v,
query_chunk_size=q_chunk_size,
kv_chunk_size=kv_chunk_size,
kv_chunk_size_min = kv_chunk_size_min,
use_checkpoint=use_checkpoint,
)
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
@ -252,12 +316,7 @@ def cross_attention_attnblock_forward(self, x):
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
mem_free_total = get_available_vram()
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
@ -312,3 +371,19 @@ def xformers_attnblock_forward(self, x):
return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
def sub_quad_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out

View File

@ -56,6 +56,10 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
@ -362,6 +366,7 @@ options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
"save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@ -429,7 +434,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
"dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"),
'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
@ -576,6 +581,7 @@ latent_upscale_modes = {
"Latent (bicubic)": {"mode": "bicubic", "antialias": False},
"Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
"Latent (nearest)": {"mode": "nearest", "antialias": False},
"Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False},
}
sd_upscalers = []

View File

@ -0,0 +1,205 @@
# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
# implementation of:
# Self-attention Does Not Need O(n2) Memory":
# https://arxiv.org/abs/2112.05682v2
from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple, Protocol, List
def narrow_trunc(
input: Tensor,
dim: int,
start: int,
length: int
) -> Tensor:
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor
class SummarizeChunk(Protocol):
@staticmethod
def __call__(
query: Tensor,
key: Tensor,
value: Tensor,
) -> AttnChunk: ...
class ComputeQueryChunkAttn(Protocol):
@staticmethod
def __call__(
query: Tensor,
key: Tensor,
value: Tensor,
) -> Tensor: ...
def _summarize_chunk(
query: Tensor,
key: Tensor,
value: Tensor,
scale: float,
) -> AttnChunk:
attn_weights = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
beta=0,
)
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score)
exp_values = torch.bmm(exp_weights, value)
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
def _query_chunk_attention(
query: Tensor,
key: Tensor,
value: Tensor,
summarize_chunk: SummarizeChunk,
kv_chunk_size: int,
) -> Tensor:
batch_x_heads, k_tokens, k_channels_per_head = key.shape
_, _, v_channels_per_head = value.shape
def chunk_scanner(chunk_idx: int) -> AttnChunk:
key_chunk = narrow_trunc(
key,
1,
chunk_idx,
kv_chunk_size
)
value_chunk = narrow_trunc(
value,
1,
chunk_idx,
kv_chunk_size
)
return summarize_chunk(query, key_chunk, value_chunk)
chunks: List[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
chunk_values, chunk_weights, chunk_max = acc_chunk
global_max, _ = torch.max(chunk_max, 0, keepdim=True)
max_diffs = torch.exp(chunk_max - global_max)
chunk_values *= torch.unsqueeze(max_diffs, -1)
chunk_weights *= max_diffs
all_values = chunk_values.sum(dim=0)
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
key: Tensor,
value: Tensor,
scale: float,
) -> Tensor:
attn_scores = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
beta=0,
)
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
hidden_states_slice = torch.bmm(attn_probs, value)
return hidden_states_slice
class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk
def efficient_dot_product_attention(
query: Tensor,
key: Tensor,
value: Tensor,
query_chunk_size=1024,
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
):
"""Computes efficient dot-product attention given query, key, and value.
This is efficient version of attention presented in
https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
Args:
query: queries for calculating attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
key: keys for calculating attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
value: values to be used in attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
query_chunk_size: int: query chunks size
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
Returns:
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
"""
batch_x_heads, q_tokens, q_channels_per_head = query.shape
_, k_tokens, _ = key.shape
scale = q_channels_per_head ** -0.5
kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
def get_query_chunk(chunk_idx: int) -> Tensor:
return narrow_trunc(
query,
1,
chunk_idx,
min(query_chunk_size, q_tokens)
)
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
_get_attention_scores_no_kv_chunking,
scale=scale
) if k_tokens <= kv_chunk_size else (
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
partial(
_query_chunk_attention,
kv_chunk_size=kv_chunk_size,
summarize_chunk=summarize_chunk,
)
)
if q_tokens <= query_chunk_size:
# fast-path for when there's just 1 query chunk
return compute_query_chunk_attn(
query=query,
key=key,
value=value,
)
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
res = torch.cat([
compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key=key,
value=value,
) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1)
return res

View File

@ -0,0 +1,24 @@
import datetime
import json
import os
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
def save_settings_to_file(log_directory, all_params):
now = datetime.datetime.now()
params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
keys = saved_params_all
if all_params.get('preview_from_txt2img'):
keys = keys | saved_params_previews
params.update({k: v for k, v in all_params.items() if k in keys})
filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
with open(os.path.join(log_directory, filename), "w") as file:
json.dump(params, file, indent=4)

View File

@ -1,6 +1,7 @@
import os
import sys
import traceback
import inspect
import torch
import tqdm
@ -17,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed,
caption_image_overlay)
from modules.textual_inversion.logging import save_settings_to_file
class Embedding:
def __init__(self, vec, name, step=None):
@ -76,7 +79,6 @@ class EmbeddingDatabase:
self.word_embeddings[embedding.name] = embedding
# TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
@ -149,19 +151,20 @@ class EmbeddingDatabase:
else:
self.skipped_embeddings[name] = embedding
for fn in os.listdir(self.embeddings_dir):
try:
fullfn = os.path.join(self.embeddings_dir, fn)
for root, dirs, fns in os.walk(self.embeddings_dir):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
if os.stat(fullfn).st_size == 0:
if os.stat(fullfn).st_size == 0:
continue
process_file(fullfn, fn)
except Exception:
print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
process_file(fullfn, fn)
except Exception:
print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
@ -229,6 +232,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
@ -292,8 +296,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
if initial_step >= steps:
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
None
@ -307,6 +311,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
if shared.opts.save_training_settings_to_txt:
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

View File

@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
from modules.ui_components import FormRow, FormGroup, ToolButton
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts
@ -256,6 +256,20 @@ def add_style(name: str, prompt: str, negative_prompt: str):
return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
from modules import processing, devices
if not enable:
return ""
p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
with devices.autocast():
p.init([""], [0], [0])
return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{p.hr_upscale_to_x}x{p.hr_upscale_to_y}</span>"
def apply_styles(prompt, prompt_neg, style1_name, style2_name):
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
@ -368,7 +382,7 @@ def update_token_counter(text, steps):
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>"
@ -435,11 +449,9 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=1, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
prompt_style.save_to_config = True
with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
prompt_style2.save_to_config = True
return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
@ -550,6 +562,8 @@ Requested path was: {f}
os.startfile(path)
elif platform.system() == "Darwin":
sp.Popen(["open", path])
elif "microsoft-standard-WSL2" in platform.uname().release:
sp.Popen(["wsl-open", path])
else:
sp.Popen(["xdg-open", path])
@ -636,7 +650,6 @@ def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
sampler_index.save_to_config = True
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
@ -707,6 +720,7 @@ def create_ui():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
elif category == "hires_fix":
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
@ -730,6 +744,17 @@ def create_ui():
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
hr_resolution_preview_args = dict(
fn=calc_resolution_hires,
inputs=hr_resolution_preview_inputs,
outputs=[hr_final_resolution],
show_progress=False
)
for input in hr_resolution_preview_inputs:
input.change(**hr_resolution_preview_args)
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@ -791,6 +816,7 @@ def create_ui():
fn=lambda x: gr_show(x),
inputs=[enable_hr],
outputs=[hr_options],
show_progress = False,
)
txt2img_paste_fields = [
@ -1792,7 +1818,7 @@ def create_ui():
if init_field is not None:
init_field(saved_value)
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
apply_field(x, 'visible')
if type(x) == gr.Slider:
@ -1813,11 +1839,8 @@ def create_ui():
if type(x) == gr.Number:
apply_field(x, 'value')
# Since there are many dropdowns that shouldn't be saved,
# we only mark dropdowns that should be saved.
if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False):
if type(x) == gr.Dropdown:
apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
apply_field(x, 'visible')
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")

View File

@ -23,3 +23,11 @@ class FormGroup(gr.Group, gr.components.FormComponent):
def get_block_name(self):
return "group"
class FormHTML(gr.HTML, gr.components.FormComponent):
"""Same as gr.HTML but fits inside gradio forms"""
def get_block_name(self):
return "html"

View File

@ -162,15 +162,15 @@ def install_extension_from_url(dirname, url):
shutil.rmtree(tmpdir, True)
def install_extension_from_index(url, hide_tags):
def install_extension_from_index(url, hide_tags, sort_column):
ext_table, message = install_extension_from_url(None, url)
code, _ = refresh_available_extensions_from_data(hide_tags)
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
return code, ext_table, message
def refresh_available_extensions(url, hide_tags):
def refresh_available_extensions(url, hide_tags, sort_column):
global available_extensions
import urllib.request
@ -179,18 +179,28 @@ def refresh_available_extensions(url, hide_tags):
available_extensions = json.loads(text)
code, tags = refresh_available_extensions_from_data(hide_tags)
code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)
return url, code, gr.CheckboxGroup.update(choices=tags), ''
def refresh_available_extensions_for_tags(hide_tags):
code, _ = refresh_available_extensions_from_data(hide_tags)
def refresh_available_extensions_for_tags(hide_tags, sort_column):
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
return code, ''
def refresh_available_extensions_from_data(hide_tags):
sort_ordering = [
# (reverse, order_by_function)
(True, lambda x: x.get('added', 'z')),
(False, lambda x: x.get('added', 'z')),
(False, lambda x: x.get('name', 'z')),
(True, lambda x: x.get('name', 'z')),
(False, lambda x: 'z'),
]
def refresh_available_extensions_from_data(hide_tags, sort_column):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@ -210,8 +220,11 @@ def refresh_available_extensions_from_data(hide_tags):
<tbody>
"""
for ext in extlist:
sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
name = ext.get("name", "noname")
added = ext.get('added', 'unknown')
url = ext.get("url", None)
description = ext.get("description", "")
extension_tags = ext.get("tags", [])
@ -233,7 +246,7 @@ def refresh_available_extensions_from_data(hide_tags):
code += f"""
<tr>
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
<td>{html.escape(description)}</td>
<td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
<td>{install_code}</td>
</tr>
@ -291,25 +304,32 @@ def create_ui():
with gr.Row():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
inputs=[available_extensions_index, hide_tags],
inputs=[available_extensions_index, hide_tags, sort_column],
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
inputs=[extension_to_install, hide_tags],
inputs=[extension_to_install, hide_tags, sort_column],
outputs=[available_extensions_table, extensions_table, install_result],
)
hide_tags.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[hide_tags],
inputs=[hide_tags, sort_column],
outputs=[available_extensions_table, install_result]
)
sort_column.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[hide_tags, sort_column],
outputs=[available_extensions_table, install_result]
)

View File

@ -30,4 +30,4 @@ inflection
GitPython
torchsde
safetensors
psutil; sys_platform == 'darwin'
psutil

Binary file not shown.

Before

Width:  |  Height:  |  Size: 513 KiB

After

Width:  |  Height:  |  Size: 411 KiB

View File

@ -555,7 +555,7 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
/* Extensions */
#tab_extensions table{
#tab_extensions table``{
border-collapse: collapse;
}
@ -581,6 +581,15 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
font-size: 95%;
}
#available_extensions .info{
margin: 0;
}
#available_extensions .date_added{
opacity: 0.85;
font-size: 90%;
}
#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{
min-width: auto;
padding-left: 0.5em;
@ -633,6 +642,23 @@ footer {
opacity: 0.85;
}
#txtimg_hr_finalres{
min-height: 0 !important;
padding: .625rem .75rem;
margin-left: -0.75em
}
#txtimg_hr_finalres .resolution{
font-weight: bold;
}
#txt2img_checkboxes > div > div{
flex: 0;
white-space: nowrap;
min-width: auto;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running

Binary file not shown.

Before

Width:  |  Height:  |  Size: 329 KiB

View File

@ -4,7 +4,7 @@ import threading
import time
import importlib
import signal
import threading
import re
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
@ -13,6 +13,11 @@ from modules import import_hook, errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path
import torch
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
import modules.codeformer_model as codeformer
import modules.extras
@ -182,12 +187,14 @@ def webui():
sd_samplers.set_samplers()
modules.script_callbacks.script_unloaded_callback()
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
modelloader.forbid_loaded_nonbuiltin_upscalers()
modules.scripts.reload_scripts()
modules.script_callbacks.model_loaded_callback(shared.sd_model)
modelloader.load_upscalers()
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: