Merge branch 'dev' into dev

This commit is contained in:
AUTOMATIC1111 2024-07-06 09:46:57 +03:00 committed by GitHub
commit 9414309f15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 2133 additions and 82 deletions

View File

@ -150,7 +150,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- Spandrel - https://github.com/chaiNNer-org/spandrel implementing
- GFPGAN - https://github.com/TencentARC/GFPGAN.git

View File

@ -0,0 +1,5 @@
model:
target: modules.models.sd3.sd3_model.SD3Inferencer
params:
shift: 3
state_dict: null

View File

@ -130,7 +130,9 @@ def assign_network_names_to_compvis_modules(sd_model):
network_layer_mapping[network_name] = module
module.network_layer_name = network_name
else:
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
for name, module in cond_stage_model.named_modules():
network_name = name.replace(".", "_")
network_layer_mapping[network_name] = module
module.network_layer_name = network_name

View File

@ -26,6 +26,14 @@ function selected_gallery_index() {
return all_gallery_buttons().findIndex(elem => elem.classList.contains('selected'));
}
function gallery_container_buttons(gallery_container) {
return gradioApp().querySelectorAll(`#${gallery_container} .thumbnail-item.thumbnail-small`);
}
function selected_gallery_index_id(gallery_container) {
return Array.from(gallery_container_buttons(gallery_container)).findIndex(elem => elem.classList.contains('selected'));
}
function extract_image_from_gallery(gallery) {
if (gallery.length == 0) {
return [null];

View File

@ -43,7 +43,7 @@ def script_name_to_index(name, scripts):
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
if config is None:
raise HTTPException(status_code=404, detail="Sampler not found")
raise HTTPException(status_code=400, detail="Sampler not found")
return name

View File

@ -30,7 +30,7 @@ parser.add_argument("--gfpgan-model", type=normalized_filepath, help="GFPGAN mod
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--max-batch-count", type=int, default=16, help="does not do anything")
parser.add_argument("--embeddings-dir", type=normalized_filepath, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--textual-inversion-templates-dir", type=normalized_filepath, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
parser.add_argument("--hypernetwork-dir", type=normalized_filepath, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")

View File

@ -57,7 +57,7 @@ class DeepDanbooru:
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
with torch.no_grad(), devices.autocast():
x = torch.from_numpy(a).to(devices.device)
x = torch.from_numpy(a).to(devices.device, devices.dtype)
y = self.model(x)[0].detach().cpu().numpy()
probability_dict = {}

View File

@ -36,13 +36,11 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
ext_filter=['.pth'],
):
if 'GFPGAN' in os.path.basename(model_path):
model = modelloader.load_spandrel_model(
return modelloader.load_spandrel_model(
model_path,
device=self.get_device(),
expected_architecture='GFPGAN',
).model
model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
return model
raise ValueError("No GFPGAN model found")
def restore(self, np_image):

View File

@ -1,9 +1,12 @@
from collections import namedtuple
import torch
from modules import devices, shared
module_in_gpu = None
cpu = torch.device("cpu")
ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])
def send_everything_to_cpu():
global module_in_gpu
@ -75,13 +78,14 @@ def setup_for_low_vram(sd_model, use_medvram):
(sd_model, 'depth_model'),
(sd_model, 'embedder'),
(sd_model, 'model'),
(sd_model, 'embedder'),
]
is_sdxl = hasattr(sd_model, 'conditioner')
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
if is_sdxl:
if hasattr(sd_model, 'medvram_fields'):
to_remain_in_cpu = sd_model.medvram_fields()
elif is_sdxl:
to_remain_in_cpu.append((sd_model, 'conditioner'))
elif is_sd2:
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
@ -103,7 +107,21 @@ def setup_for_low_vram(sd_model, use_medvram):
setattr(obj, field, module)
# register hooks for those the first three models
if is_sdxl:
if hasattr(sd_model, "cond_stage_model") and hasattr(sd_model.cond_stage_model, "medvram_modules"):
for module in sd_model.cond_stage_model.medvram_modules():
if isinstance(module, ModuleWithParent):
parent = module.parent
module = module.module
else:
parent = None
if module:
module.register_forward_pre_hook(send_me_to_gpu)
if parent:
parents[module] = parent
elif is_sdxl:
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
elif is_sd2:
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
@ -117,9 +135,9 @@ def setup_for_low_vram(sd_model, use_medvram):
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
if sd_model.depth_model:
if getattr(sd_model, 'depth_model', None) is not None:
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
if sd_model.embedder:
if getattr(sd_model, 'embedder', None) is not None:
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
if use_medvram:

View File

@ -139,6 +139,27 @@ def load_upscalers():
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
)
# None: not loaded, False: failed to load, True: loaded
_spandrel_extra_init_state = None
def _init_spandrel_extra_archs() -> None:
"""
Try to initialize `spandrel_extra_archs` (exactly once).
"""
global _spandrel_extra_init_state
if _spandrel_extra_init_state is not None:
return
try:
import spandrel
import spandrel_extra_arches
spandrel.MAIN_REGISTRY.add(*spandrel_extra_arches.EXTRA_REGISTRY)
_spandrel_extra_init_state = True
except Exception:
logger.warning("Failed to load spandrel_extra_arches", exc_info=True)
_spandrel_extra_init_state = False
def load_spandrel_model(
path: str | os.PathLike,
@ -148,11 +169,16 @@ def load_spandrel_model(
dtype: str | torch.dtype | None = None,
expected_architecture: str | None = None,
) -> spandrel.ModelDescriptor:
global _spandrel_extra_init_state
import spandrel
_init_spandrel_extra_archs()
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
if expected_architecture and model_descriptor.architecture != expected_architecture:
arch = model_descriptor.architecture
if expected_architecture and arch.name != expected_architecture:
logger.warning(
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
f"Model {path!r} is not a {expected_architecture!r} model (got {arch.name!r})",
)
half = False
if prefer_half:
@ -166,6 +192,6 @@ def load_spandrel_model(
model_descriptor.model.eval()
logger.debug(
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
model_descriptor, path, device, half, dtype,
arch, path, device, half, dtype,
)
return model_descriptor

619
modules/models/sd3/mmdit.py Normal file
View File

@ -0,0 +1,619 @@
### This file contains impls for MM-DiT, the core model component of SD3
import math
from typing import Dict, Optional
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from modules.models.sd3.other_impls import attention, Mlp
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding"""
def __init__(
self,
img_size: Optional[int] = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
flatten: bool = True,
bias: bool = True,
strict_img_size: bool = True,
dynamic_img_pad: bool = False,
dtype=None,
device=None,
):
super().__init__()
self.patch_size = (patch_size, patch_size)
if img_size is not None:
self.img_size = (img_size, img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.img_size = None
self.grid_size = None
self.num_patches = None
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
return x
def modulate(x, shift, scale):
if shift is None:
shift = torch.zeros_like(scale)
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scaling_factor=None, offset=None):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
if scaling_factor is not None:
grid = grid / scaling_factor
if offset is not None:
grid = grid - offset
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
#################################################################################
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(nn.Module):
"""Embeds scalar timesteps into vector representations."""
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(dtype=t.dtype)
return embedding
def forward(self, t, dtype, **kwargs):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
class VectorEmbedder(nn.Module):
"""Embeds a flat vector of dimension input_dim"""
def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
#################################################################################
# Core DiT Model #
#################################################################################
def split_qkv(qkv, head_dim):
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
return qkv[0], qkv[1], qkv[2]
def optimized_attention(qkv, num_heads):
return attention(qkv[0], qkv[1], qkv[2], num_heads)
class SelfAttention(nn.Module):
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: Optional[float] = None,
attn_mode: str = "xformers",
pre_only: bool = False,
qk_norm: Optional[str] = None,
rmsnorm: bool = False,
dtype=None,
device=None,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
if not pre_only:
self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
assert attn_mode in self.ATTENTION_MODES
self.attn_mode = attn_mode
self.pre_only = pre_only
if qk_norm == "rms":
self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
elif qk_norm == "ln":
self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
elif qk_norm is None:
self.ln_q = nn.Identity()
self.ln_k = nn.Identity()
else:
raise ValueError(qk_norm)
def pre_attention(self, x: torch.Tensor):
B, L, C = x.shape
qkv = self.qkv(x)
q, k, v = split_qkv(qkv, self.head_dim)
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
return (q, k, v)
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
assert not self.pre_only
x = self.proj(x)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
(q, k, v) = self.pre_attention(x)
x = attention(q, k, v, self.num_heads)
x = self.post_attention(x)
return x
class RMSNorm(torch.nn.Module):
def __init__(
self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
super().__init__()
self.eps = eps
self.learnable_scale = elementwise_affine
if self.learnable_scale:
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
else:
self.register_parameter("weight", None)
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
x = self._norm(x)
if self.learnable_scale:
return x * self.weight.to(device=x.device, dtype=x.dtype)
else:
return x
class SwiGLUFeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float] = None,
):
"""
Initialize the FeedForward module.
Args:
dim (int): Input dimension.
hidden_dim (int): Hidden dimension of the feedforward layer.
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
Attributes:
w1 (ColumnParallelLinear): Linear transformation for the first layer.
w2 (RowParallelLinear): Linear transformation for the second layer.
w3 (ColumnParallelLinear): Linear transformation for the third layer.
"""
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
class DismantledBlock(nn.Module):
"""A DiT block with gated adaptive layer norm (adaLN) conditioning."""
ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: str = "xformers",
qkv_bias: bool = False,
pre_only: bool = False,
rmsnorm: bool = False,
scale_mod_only: bool = False,
swiglu: bool = False,
qk_norm: Optional[str] = None,
dtype=None,
device=None,
**block_kwargs,
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if not rmsnorm:
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
else:
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=pre_only, qk_norm=qk_norm, rmsnorm=rmsnorm, dtype=dtype, device=device)
if not pre_only:
if not rmsnorm:
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
else:
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
if not pre_only:
if not swiglu:
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=nn.GELU(approximate="tanh"), dtype=dtype, device=device)
else:
self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)
self.scale_mod_only = scale_mod_only
if not scale_mod_only:
n_mods = 6 if not pre_only else 2
else:
n_mods = 4 if not pre_only else 1
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device))
self.pre_only = pre_only
def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
assert x is not None, "pre_attention called with None input"
if not self.pre_only:
if not self.scale_mod_only:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
else:
shift_msa = None
shift_mlp = None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
else:
if not self.scale_mod_only:
shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
else:
shift_msa = None
scale_msa = self.adaLN_modulation(c)
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
return qkv, None
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
assert not self.pre_only
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
assert not self.pre_only
(q, k, v), intermediates = self.pre_attention(x, c)
attn = attention(q, k, v, self.attn.num_heads)
return self.post_attention(attn, *intermediates)
def block_mixing(context, x, context_block, x_block, c):
assert context is not None, "block_mixing called with None context"
context_qkv, context_intermediates = context_block.pre_attention(context, c)
x_qkv, x_intermediates = x_block.pre_attention(x, c)
o = []
for t in range(3):
o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))
q, k, v = tuple(o)
attn = attention(q, k, v, x_block.attn.num_heads)
context_attn, x_attn = (attn[:, : context_qkv[0].shape[1]], attn[:, context_qkv[0].shape[1] :])
if not context_block.pre_only:
context = context_block.post_attention(context_attn, *context_intermediates)
else:
context = None
x = x_block.post_attention(x_attn, *x_intermediates)
return context, x
class JointBlock(nn.Module):
"""just a small wrapper to serve as a fsdp unit"""
def __init__(self, *args, **kwargs):
super().__init__()
pre_only = kwargs.pop("pre_only")
qk_norm = kwargs.pop("qk_norm", None)
self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)
def forward(self, *args, **kwargs):
return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, total_out_channels: Optional[int] = None, dtype=None, device=None):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = (
nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
if (total_out_channels is None)
else nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class MMDiT(nn.Module):
"""Diffusion model with a Transformer backbone."""
def __init__(
self,
input_size: int = 32,
patch_size: int = 2,
in_channels: int = 4,
depth: int = 28,
mlp_ratio: float = 4.0,
learn_sigma: bool = False,
adm_in_channels: Optional[int] = None,
context_embedder_config: Optional[Dict] = None,
register_length: int = 0,
attn_mode: str = "torch",
rmsnorm: bool = False,
scale_mod_only: bool = False,
swiglu: bool = False,
out_channels: Optional[int] = None,
pos_embed_scaling_factor: Optional[float] = None,
pos_embed_offset: Optional[float] = None,
pos_embed_max_size: Optional[int] = None,
num_patches = None,
qk_norm: Optional[str] = None,
qkv_bias: bool = True,
dtype = None,
device = None,
):
super().__init__()
self.dtype = dtype
self.learn_sigma = learn_sigma
self.in_channels = in_channels
default_out_channels = in_channels * 2 if learn_sigma else in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.patch_size = patch_size
self.pos_embed_scaling_factor = pos_embed_scaling_factor
self.pos_embed_offset = pos_embed_offset
self.pos_embed_max_size = pos_embed_max_size
# apply magic --> this defines a head_size of 64
hidden_size = 64 * depth
num_heads = depth
self.num_heads = num_heads
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=self.pos_embed_max_size is None, dtype=dtype, device=device)
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)
if adm_in_channels is not None:
assert isinstance(adm_in_channels, int)
self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)
self.context_embedder = nn.Identity()
if context_embedder_config is not None:
if context_embedder_config["target"] == "torch.nn.Linear":
self.context_embedder = nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)
self.register_length = register_length
if self.register_length > 0:
self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))
# num_patches = self.x_embedder.num_patches
# Will use fixed sin-cos embedding:
# just use a buffer already
if num_patches is not None:
self.register_buffer(
"pos_embed",
torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
)
else:
self.pos_embed = None
self.joint_blocks = nn.ModuleList(
[
JointBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=i == depth - 1, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, dtype=dtype, device=device)
for i in range(depth)
]
)
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)
def cropped_pos_embed(self, hw):
assert self.pos_embed_max_size is not None
p = self.x_embedder.patch_size[0]
h, w = hw
# patched size
h = h // p
w = w // p
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
top = (self.pos_embed_max_size - h) // 2
left = (self.pos_embed_max_size - w) // 2
spatial_pos_embed = rearrange(
self.pos_embed,
"1 (h w) c -> 1 h w c",
h=self.pos_embed_max_size,
w=self.pos_embed_max_size,
)
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
return spatial_pos_embed
def unpatchify(self, x, hw=None):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
if hw is None:
h = w = int(x.shape[1] ** 0.5)
else:
h, w = hw
h = h // p
w = w // p
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
def forward_core_with_concat(self, x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.register_length > 0:
context = torch.cat((repeat(self.register, "1 ... -> b ...", b=x.shape[0]), context if context is not None else torch.Tensor([]).type_as(x)), 1)
# context is B, L', D
# x is B, L, D
for block in self.joint_blocks:
context, x = block(context, x, c=c_mod)
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
return x
def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
hw = x.shape[-2:]
x = self.x_embedder(x) + self.cropped_pos_embed(hw)
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
if y is not None:
y = self.y_embedder(y) # (N, D)
c = c + y # (N, D)
context = self.context_embedder(context)
x = self.forward_core_with_concat(x, c, context)
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
return x

View File

@ -0,0 +1,508 @@
### This file contains impls for underlying related models (CLIP, T5, etc)
import torch
import math
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast
#################################################################################################
### Core/Utility
#################################################################################################
class AutocastLinear(nn.Linear):
"""Same as usual linear layer, but casts its weights to whatever the parameter type is.
This is different from torch.autocast in a way that float16 layer processing float32 input
will return float16 with autocast on, and float32 with this. T5 seems to be fucked
if you do it in full float16 (returning almost all zeros in the final output).
"""
def forward(self, x):
return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
def attention(q, k, v, heads, mask=None):
"""Convenience wrapper around a basic attention operation"""
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
self.act = act_layer
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
#################################################################################################
### CLIP
#################################################################################################
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device):
super().__init__()
self.heads = heads
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = attention(q, k, v, self.heads, mask)
return self.out_proj(out)
ACTIVATIONS = {
"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
}
class CLIPLayer(torch.nn.Module):
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
super().__init__()
self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
#self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)
def forward(self, x, mask=None):
x += self.self_attn(self.layer_norm1(x), mask)
x += self.mlp(self.layer_norm2(x))
return x
class CLIPEncoder(torch.nn.Module):
def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
super().__init__()
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None):
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
for i, layer in enumerate(self.layers):
x = layer(x, mask)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
super().__init__()
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens):
return self.token_embedding(input_tokens) + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device):
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)
x = self.final_layer_norm(x)
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device):
super().__init__()
self.num_layers = config_dict["num_hidden_layers"]
self.text_model = CLIPTextModel_(config_dict, dtype, device)
embed_dim = config_dict["hidden_size"]
self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, embeddings):
self.text_model.embeddings.token_embedding = embeddings
def forward(self, *args, **kwargs):
x = self.text_model(*args, **kwargs)
out = self.text_projection(x[2])
return (x[0], x[1], out, x[2])
class SDTokenizer:
def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, pad_to_max_length=True, min_length=None):
self.tokenizer = tokenizer
self.max_length = max_length
self.min_length = min_length
empty = self.tokenizer('')["input_ids"]
if has_start_token:
self.tokens_start = 1
self.start_token = empty[0]
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = None
self.end_token = empty[0]
self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.max_word_length = 8
def tokenize_with_weights(self, text:str):
"""Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""
if self.pad_with_end:
pad_token = self.end_token
else:
pad_token = 0
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0))
to_tokenize = text.replace("\n", " ").split(' ')
to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize:
batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
batch.append((self.end_token, 1.0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
return [batch]
class SDXLClipGTokenizer(SDTokenizer):
def __init__(self, tokenizer):
super().__init__(pad_with_end=False, tokenizer=tokenizer)
class SD3Tokenizer:
def __init__(self):
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
self.t5xxl = T5XXLTokenizer()
def tokenize_with_weights(self, text:str):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text)
out["l"] = self.clip_l.tokenize_with_weights(text)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
return out
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
tokens = [a[0] for a in token_weight_pairs[0]]
out, pooled = self([tokens])
if pooled is not None:
first_pooled = pooled[0:1].cpu()
else:
first_pooled = pooled
output = [out[0:1]]
return torch.cat(output, dim=-2).cpu(), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = ["last", "pooled", "hidden"]
def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
super().__init__()
assert layer in self.LAYERS
self.transformer = model_class(textmodel_json_config, dtype, device)
self.num_layers = self.transformer.num_layers
self.max_length = max_length
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
self.layer = layer
self.layer_idx = None
self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) < self.num_layers
self.set_clip_options({"layer": layer_idx})
self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
self.layer_idx = layer_idx
def forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()
tokens = torch.asarray(tokens, dtype=torch.int64, device=backup_embeds.weight.device)
outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs[0]
else:
z = outputs[1]
pooled_output = None
if len(outputs) >= 3:
if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
pooled_output = outputs[3].float()
elif outputs[2] is not None:
pooled_output = outputs[2].float()
return z.float(), pooled_output
class SDXLClipG(SDClipModel):
"""Wraps the CLIP-G model into the SD-CLIP-Model interface"""
def __init__(self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
class T5XXLModel(SDClipModel):
"""Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""
def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)
#################################################################################################
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################
class T5XXLTokenizer(SDTokenizer):
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
def __init__(self):
super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
class T5LayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))
self.variance_epsilon = eps
def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
return self.weight.to(device=x.device, dtype=x.dtype) * x
class T5DenseGatedActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device):
super().__init__()
self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
def forward(self, x):
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
hidden_linear = self.wi_1(x)
x = hidden_gelu * hidden_linear
x = self.wo(x)
return x
class T5LayerFF(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device):
super().__init__()
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
def forward(self, x):
forwarded_states = self.layer_norm(x)
forwarded_states = self.DenseReluDense(forwarded_states)
x += forwarded_states
return x
class T5Attention(torch.nn.Module):
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
super().__init__()
# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
self.num_heads = num_heads
self.relative_attention_bias = None
if relative_attention_bias:
self.relative_attention_num_buckets = 32
self.relative_attention_max_distance = 128
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length, device):
"""Compute binned relative position bias"""
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=True,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
def forward(self, x, past_bias=None):
q = self.q(x)
k = self.k(x)
v = self.v(x)
if self.relative_attention_bias is not None:
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
if past_bias is not None:
mask = past_bias
else:
mask = None
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
return self.o(out), past_bias
class T5LayerSelfAttention(torch.nn.Module):
def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
super().__init__()
self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
def forward(self, x, past_bias=None):
output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
x += output
return x, past_bias
class T5Block(torch.nn.Module):
def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):
super().__init__()
self.layer = torch.nn.ModuleList()
self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device))
self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))
def forward(self, x, past_bias=None):
x, past_bias = self.layer[0](x, past_bias)
x = self.layer[-1](x)
return x, past_bias
class T5Stack(torch.nn.Module):
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):
super().__init__()
self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)
def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):
intermediate = None
x = self.embed_tokens(input_ids).to(torch.float32) # needs float32 or else T5 returns all zeroes
past_bias = None
for i, layer in enumerate(self.block):
x, past_bias = layer(x, past_bias)
if i == intermediate_output:
intermediate = x.clone()
x = self.final_layer_norm(x)
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.final_layer_norm(intermediate)
return x, intermediate
class T5(torch.nn.Module):
def __init__(self, config_dict, dtype, device):
super().__init__()
self.num_layers = config_dict["num_layers"]
self.encoder = T5Stack(self.num_layers, config_dict["d_model"], config_dict["d_model"], config_dict["d_ff"], config_dict["num_heads"], config_dict["vocab_size"], dtype, device)
self.dtype = dtype
def get_input_embeddings(self):
return self.encoder.embed_tokens
def set_input_embeddings(self, embeddings):
self.encoder.embed_tokens = embeddings
def forward(self, *args, **kwargs):
return self.encoder(*args, **kwargs)

View File

@ -0,0 +1,218 @@
import os
import safetensors
import torch
import typing
from transformers import CLIPTokenizer, T5TokenizerFast
from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
class SafetensorsMapping(typing.Mapping):
def __init__(self, file):
self.file = file
def __len__(self):
return len(self.file.keys())
def __iter__(self):
for key in self.file.keys():
yield key
def __getitem__(self, key):
return self.file.get_tensor(key)
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
CLIPL_CONFIG = {
"hidden_act": "quick_gelu",
"hidden_size": 768,
"intermediate_size": 3072,
"num_attention_heads": 12,
"num_hidden_layers": 12,
}
CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
CLIPG_CONFIG = {
"hidden_act": "gelu",
"hidden_size": 1280,
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
}
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
T5_CONFIG = {
"d_ff": 10240,
"d_model": 4096,
"num_heads": 64,
"num_layers": 24,
"vocab_size": 32128,
}
class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
def __init__(self, clip_l, clip_g):
super().__init__()
self.clip_l = clip_l
self.clip_g = clip_g
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
empty = self.tokenizer('')["input_ids"]
self.id_start = empty[0]
self.id_end = empty[1]
self.id_pad = empty[1]
self.return_pooled = True
def tokenize(self, texts):
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
def encode_with_transformers(self, tokens):
tokens_g = tokens.clone()
for batch_pos in range(tokens_g.shape[0]):
index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0
l_out, l_pooled = self.clip_l(tokens)
g_out, g_pooled = self.clip_g(tokens_g)
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
lg_out.pooled = vector_out
return lg_out
def encode_embedding_init_text(self, init_text, nvpt):
return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX
class Sd3T5(torch.nn.Module):
def __init__(self, t5xxl):
super().__init__()
self.t5xxl = t5xxl
self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
self.id_end = empty[0]
self.id_pad = empty[1]
def tokenize(self, texts):
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
def tokenize_line(self, line, *, target_token_count=None):
if shared.opts.emphasis != "None":
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
tokenized = self.tokenize([text for text, _ in parsed])
tokens = []
multipliers = []
for text_tokens, (text, weight) in zip(tokenized, parsed):
if text == 'BREAK' and weight == -1:
continue
tokens += text_tokens
multipliers += [weight] * len(text_tokens)
tokens += [self.id_end]
multipliers += [1.0]
if target_token_count is not None:
if len(tokens) < target_token_count:
tokens += [self.id_pad] * (target_token_count - len(tokens))
multipliers += [1.0] * (target_token_count - len(tokens))
else:
tokens = tokens[0:target_token_count]
multipliers = multipliers[0:target_token_count]
return tokens, multipliers
def forward(self, texts, *, token_count):
if not self.t5xxl or not shared.opts.sd3_enable_t5:
return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)
tokens_batch = []
for text in texts:
tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)
tokens_batch.append(tokens)
t5_out, t5_pooled = self.t5xxl(tokens_batch)
return t5_out
def encode_embedding_init_text(self, init_text, nvpt):
return torch.zeros((nvpt, 4096), device=devices.device) # XXX
class SD3Cond(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer = SD3Tokenizer()
with torch.no_grad():
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
if shared.opts.sd3_enable_t5:
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
else:
self.t5xxl = None
self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
self.model_t5 = Sd3T5(self.t5xxl)
def forward(self, prompts: list[str]):
with devices.without_autocast():
lg_out, vector_out = self.model_lg(prompts)
t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
return {
'crossattn': lgt_out,
'vector': vector_out,
}
def before_load_weights(self, state_dict):
clip_path = os.path.join(shared.models_path, "CLIP")
if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
with safetensors.safe_open(clip_g_file, framework="pt") as file:
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
with safetensors.safe_open(clip_l_file, framework="pt") as file:
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
with safetensors.safe_open(t5_file, framework="pt") as file:
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
def encode_embedding_init_text(self, init_text, nvpt):
return torch.tensor([[0]], device=devices.device) # XXX
def medvram_modules(self):
return [self.clip_g, self.clip_l, self.t5xxl]
def get_token_count(self, text):
_, token_count = self.model_lg.process_texts([text])
return token_count
def get_target_prompt_token_count(self, token_count):
return self.model_lg.get_target_prompt_token_count(token_count)

View File

@ -0,0 +1,373 @@
### Impls of the SD3 core diffusion model and VAE
import torch
import math
import einops
from modules.models.sd3.mmdit import MMDiT
from PIL import Image
#################################################################################################
### MMDiT Model Wrapping
#################################################################################################
class ModelSamplingDiscreteFlow(torch.nn.Module):
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
def __init__(self, shift=1.0):
super().__init__()
self.shift = shift
timesteps = 1000
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
self.register_buffer('sigmas', ts)
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * 1000
def sigma(self, timestep: torch.Tensor):
timestep = timestep / 1000.0
if self.shift == 1.0:
return timestep
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
return sigma * noise + (1.0 - sigma) * latent_image
class BaseModel(torch.nn.Module):
"""Wrapper around the core MM-DiT model"""
def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
super().__init__()
# Important configuration values can be quickly determined by checking shapes in the source file
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
context_embedder_config = {
"target": "torch.nn.Linear",
"params": {
"in_features": context_shape[1],
"out_features": context_shape[0]
}
}
self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
def apply_model(self, x, sigma, c_crossattn=None, y=None):
dtype = self.get_dtype()
timestep = self.model_sampling.timestep(sigma).float()
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
def get_dtype(self):
return self.diffusion_model.dtype
class CFGDenoiser(torch.nn.Module):
"""Helper for applying CFG Scaling to diffusion outputs"""
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, x, timestep, cond, uncond, cond_scale):
# Run cond and uncond in a batch together
batched = self.model.apply_model(torch.cat([x, x]), torch.cat([timestep, timestep]), c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]), y=torch.cat([cond["y"], uncond["y"]]))
# Then split and apply CFG Scaling
pos_out, neg_out = batched.chunk(2)
scaled = neg_out + (pos_out - neg_out) * cond_scale
return scaled
class SD3LatentFormat:
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
def __init__(self):
self.scale_factor = 1.5305
self.shift_factor = 0.0609
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
def decode_latent_to_preview(self, x0):
"""Quick RGB approximate preview of sd3 latents"""
factors = torch.tensor([
[-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650],
[ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889],
[ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284],
[ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039],
[ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481],
[ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259]
], device="cpu")
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
latents_ubyte = (((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()).cpu()
return Image.fromarray(latents_ubyte.numpy())
#################################################################################################
### K-Diffusion Sampling
#################################################################################################
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
return x[(...,) + (None,) * dims_to_append]
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / append_dims(sigma, x.ndim)
@torch.no_grad()
@torch.autocast("cuda", dtype=torch.float16)
def sample_euler(model, x, sigmas, extra_args=None):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in range(len(sigmas) - 1):
sigma_hat = sigmas[i]
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
return x
#################################################################################################
### VAE
#################################################################################################
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
class ResnetBlock(torch.nn.Module):
def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
if self.in_channels != self.out_channels:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
else:
self.nin_shortcut = None
self.swish = torch.nn.SiLU(inplace=True)
def forward(self, x):
hidden = x
hidden = self.norm1(hidden)
hidden = self.swish(hidden)
hidden = self.conv1(hidden)
hidden = self.norm2(hidden)
hidden = self.swish(hidden)
hidden = self.conv2(hidden)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + hidden
class AttnBlock(torch.nn.Module):
def __init__(self, in_channels, dtype=torch.float32, device=None):
super().__init__()
self.norm = Normalize(in_channels, dtype=dtype, device=device)
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
def forward(self, x):
hidden = self.norm(x)
q = self.q(hidden)
k = self.k(hidden)
v = self.v(hidden)
b, c, h, w = q.shape
q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
hidden = self.proj_out(hidden)
return x + hidden
class Downsample(torch.nn.Module):
def __init__(self, in_channels, dtype=torch.float32, device=None):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)
def forward(self, x):
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(torch.nn.Module):
def __init__(self, in_channels, dtype=torch.float32, device=None):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class VAEEncoder(torch.nn.Module):
def __init__(self, ch=128, ch_mult=(1,2,4,4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None):
super().__init__()
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = torch.nn.ModuleList()
for i_level in range(self.num_resolutions):
block = torch.nn.ModuleList()
attn = torch.nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for _ in range(num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
block_in = block_out
down = torch.nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, dtype=dtype, device=device)
self.down.append(down)
# middle
self.mid = torch.nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
# end
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
self.swish = torch.nn.SiLU(inplace=True)
def forward(self, x):
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = self.swish(h)
h = self.conv_out(h)
return h
class VAEDecoder(torch.nn.Module):
def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, resolution=256, z_channels=16, dtype=torch.float32, device=None):
super().__init__()
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
# middle
self.mid = torch.nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
# upsampling
self.up = torch.nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = torch.nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
block_in = block_out
up = torch.nn.Module()
up.block = block
if i_level != 0:
up.upsample = Upsample(block_in, dtype=dtype, device=device)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in, dtype=dtype, device=device)
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
self.swish = torch.nn.SiLU(inplace=True)
def forward(self, z):
# z to block_in
hidden = self.conv_in(z)
# middle
hidden = self.mid.block_1(hidden)
hidden = self.mid.attn_1(hidden)
hidden = self.mid.block_2(hidden)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
hidden = self.up[i_level].block[i_block](hidden)
if i_level != 0:
hidden = self.up[i_level].upsample(hidden)
# end
hidden = self.norm_out(hidden)
hidden = self.swish(hidden)
hidden = self.conv_out(hidden)
return hidden
class SDVAE(torch.nn.Module):
def __init__(self, dtype=torch.float32, device=None):
super().__init__()
self.encoder = VAEEncoder(dtype=dtype, device=device)
self.decoder = VAEDecoder(dtype=dtype, device=device)
@torch.autocast("cuda", dtype=torch.float16)
def decode(self, latent):
return self.decoder(latent)
@torch.autocast("cuda", dtype=torch.float16)
def encode(self, image):
hidden = self.encoder(image)
mean, logvar = torch.chunk(hidden, 2, dim=1)
logvar = torch.clamp(logvar, -30.0, 20.0)
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)

View File

@ -0,0 +1,84 @@
import contextlib
import torch
import k_diffusion
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
from modules.models.sd3.sd3_cond import SD3Cond
from modules import shared, devices
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
def __init__(self, inner_model, sigmas):
super().__init__(sigmas, quantize=shared.opts.enable_quantization)
self.inner_model = inner_model
def forward(self, input, sigma, **kwargs):
return self.inner_model.apply_model(input, sigma, **kwargs)
class SD3Inferencer(torch.nn.Module):
def __init__(self, state_dict, shift=3, use_ema=False):
super().__init__()
self.shift = shift
with torch.no_grad():
self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
self.first_stage_model.dtype = self.model.diffusion_model.dtype
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
self.text_encoders = SD3Cond()
self.cond_stage_key = 'txt'
self.parameterization = "eps"
self.model.conditioning_key = "crossattn"
self.latent_format = SD3LatentFormat()
self.latent_channels = 16
@property
def cond_stage_model(self):
return self.text_encoders
def before_load_weights(self, state_dict):
self.cond_stage_model.before_load_weights(state_dict)
def ema_scope(self):
return contextlib.nullcontext()
def get_learned_conditioning(self, batch: list[str]):
return self.cond_stage_model(batch)
def apply_model(self, x, t, cond):
return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
def decode_first_stage(self, latent):
latent = self.latent_format.process_out(latent)
return self.first_stage_model.decode(latent)
def encode_first_stage(self, image):
latent = self.first_stage_model.encode(image)
return self.latent_format.process_in(latent)
def get_first_stage_encoding(self, x):
return x
def create_denoiser(self):
return SD3Denoiser(self, self.model.model_sampling.sigmas)
def medvram_fields(self):
return [
(self, 'first_stage_model'),
(self, 'text_encoders'),
(self, 'model'),
]
def add_noise_to_latent(self, x, noise, amount):
return x * (1 - amount) + noise * amount
def fix_dimensions(self, width, height):
return width // 16 * 16, height // 16 * 16

View File

@ -884,6 +884,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.refiner_checkpoint_info is None:
raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
if hasattr(shared.sd_model, 'fix_dimensions'):
p.width, p.height = shared.sd_model.fix_dimensions(p.width, p.height)
p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
p.sd_model_hash = shared.sd_model.sd_model_hash
p.sd_vae_name = sd_vae.get_loaded_vae_name()
@ -942,7 +945,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
if p.scripts is not None:
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
@ -1736,10 +1740,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
latmask = latmask[0]
if self.mask_round:
latmask = np.around(latmask)
latmask = np.tile(latmask[None], (4, 1, 1))
latmask = np.tile(latmask[None], (self.init_latent.shape[1], 1, 1))
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(devices.dtype)
self.nmask = torch.asarray(latmask).to(shared.device).type(devices.dtype)
# this needs to be fixed to be done in sample() using actual seeds for batches
if self.inpainting_fill == 2:

View File

@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
class DictWithShape(dict):
def __init__(self, x, shape):
def __init__(self, x, shape=None):
super().__init__()
self.update(x)

View File

@ -325,7 +325,10 @@ class StableDiffusionModelHijack:
if self.clip is None:
return "-", "-"
_, token_count = self.clip.process_texts([text])
if hasattr(self.clip, 'get_token_count'):
token_count = self.clip.get_token_count(text)
else:
_, token_count = self.clip.process_texts([text])
return token_count, self.clip.get_target_prompt_token_count(token_count)

View File

@ -27,24 +27,21 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
"""A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
have unlimited prompt length and assign weights to tokens in prompt.
"""
def __init__(self, wrapped, hijack):
class TextConditionalModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.wrapped = wrapped
"""Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
depending on model."""
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
self.hijack = sd_hijack.model_hijack
self.chunk_length = 75
self.is_trainable = getattr(wrapped, 'is_trainable', False)
self.input_key = getattr(wrapped, 'input_key', 'txt')
self.legacy_ucg_val = None
self.is_trainable = False
self.input_key = 'txt'
self.return_pooled = False
self.comma_token = None
self.id_start = None
self.id_end = None
self.id_pad = None
def empty_chunk(self):
"""creates an empty PromptChunk and returns it"""
@ -210,10 +207,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
"""
if opts.use_old_emphasis_implementation:
import modules.sd_hijack_clip_old
return modules.sd_hijack_clip_old.forward_old(self, texts)
batch_chunks, token_count = self.process_texts(texts)
used_embeddings = {}
@ -252,7 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
if getattr(self.wrapped, 'return_pooled', False):
if self.return_pooled:
return torch.hstack(zs), zs[0].pooled
else:
return torch.hstack(zs)
@ -292,6 +285,34 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
return z
class FrozenCLIPEmbedderWithCustomWordsBase(TextConditionalModel):
"""A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
have unlimited prompt length and assign weights to tokens in prompt.
"""
def __init__(self, wrapped, hijack):
super().__init__()
self.hijack = hijack
self.wrapped = wrapped
"""Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
depending on model."""
self.is_trainable = getattr(wrapped, 'is_trainable', False)
self.input_key = getattr(wrapped, 'input_key', 'txt')
self.return_pooled = getattr(self.wrapped, 'return_pooled', False)
self.legacy_ucg_val = None # for sgm codebase
def forward(self, texts):
if opts.use_old_emphasis_implementation:
import modules.sd_hijack_clip_old
return modules.sd_hijack_clip_old.forward_old(self, texts)
return super().forward(texts)
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)

View File

@ -1,7 +1,9 @@
import collections
import importlib
import os
import sys
import threading
import enum
import torch
import re
@ -10,8 +12,6 @@ from omegaconf import OmegaConf, ListConfig
from urllib import request
import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
from modules.shared import opts
@ -27,6 +27,14 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()
class ModelType(enum.Enum):
SD1 = 1
SD2 = 2
SDXL = 3
SSD = 4
SD3 = 5
def replace_key(d, key, new_key, value):
keys = list(d.keys())
@ -368,6 +376,37 @@ def check_fp8(model):
return enable_fp8
def set_model_type(model, state_dict):
model.is_sd1 = False
model.is_sd2 = False
model.is_sdxl = False
model.is_ssd = False
model.is_sd3 = False
if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
model.is_sd3 = True
model.model_type = ModelType.SD3
elif hasattr(model, 'conditioner'):
model.is_sdxl = True
if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
model.is_ssd = True
model.model_type = ModelType.SSD
else:
model.model_type = ModelType.SDXL
elif hasattr(model.cond_stage_model, 'model'):
model.is_sd2 = True
model.model_type = ModelType.SD2
else:
model.is_sd1 = True
model.model_type = ModelType.SD1
def set_model_fields(model):
if not hasattr(model, 'latent_channels'):
model.latent_channels = 4
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
@ -382,10 +421,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
model.is_sdxl = hasattr(model, 'conditioner')
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
set_model_type(model, state_dict)
set_model_fields(model)
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
@ -396,9 +434,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy()
if hasattr(model, "before_load_weights"):
model.before_load_weights(state_dict)
model.load_state_dict(state_dict, strict=False)
timer.record("apply weights to model")
if hasattr(model, "after_load_weights"):
model.after_load_weights(state_dict)
del state_dict
# Set is_sdxl_inpaint flag.
@ -552,8 +596,7 @@ def patch_given_betas():
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
def repair_config(sd_config):
def repair_config(sd_config, state_dict=None):
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
@ -563,8 +606,9 @@ def repair_config(sd_config):
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
sd_config.model.params.unet_config.params.use_fp16 = True
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
if hasattr(sd_config.model.params, 'first_stage_config'):
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
# For UnCLIP-L, override the hardcoded karlo directory
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
@ -580,6 +624,7 @@ def repair_config(sd_config):
sd_config.model.params.unet_config.params.use_checkpoint = False
def rescale_zero_terminal_snr_abar(alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt()
@ -679,11 +724,15 @@ def get_empty_cond(sd_model):
p = processing.StableDiffusionProcessingTxt2Img()
extra_networks.activate(p, {})
if hasattr(sd_model, 'conditioner'):
if hasattr(sd_model, 'get_learned_conditioning'):
d = sd_model.get_learned_conditioning([""])
return d['crossattn']
else:
return sd_model.cond_stage_model([""])
d = sd_model.cond_stage_model([""])
if isinstance(d, dict):
d = d['crossattn']
return d
def send_model_to_cpu(m):
@ -715,6 +764,25 @@ def send_model_to_trash(m):
devices.torch_gc()
def instantiate_from_config(config, state_dict=None):
constructor = get_obj_from_str(config["target"])
params = {**config.get("params", {})}
if state_dict and "state_dict" in params and params["state_dict"] is None:
params["state_dict"] = state_dict
return constructor(**params)
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@ -739,7 +807,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("find config")
sd_config = OmegaConf.load(checkpoint_config)
repair_config(sd_config)
repair_config(sd_config, state_dict)
timer.record("load config")
@ -749,7 +817,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)
sd_model = instantiate_from_config(sd_config.model, state_dict)
except Exception as e:
errors.display(e, "creating model quickly", full_traceback=True)
@ -758,7 +826,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)
sd_model = instantiate_from_config(sd_config.model, state_dict)
sd_model.used_config = checkpoint_config
@ -775,6 +843,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
send_model_to_device(sd_model)

View File

@ -23,6 +23,8 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml"
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
def is_using_v_parameterization_for_sd2(state_dict):
"""
@ -71,11 +73,15 @@ def guess_model_config_from_state_dict(sd, filename):
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
if "model.diffusion_model.x_embedder.proj.weight" in sd:
return config_sd3
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
if diffusion_model_input.shape[1] == 9:
return config_sdxl_inpainting
else:
return config_sdxl
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
return config_sdxl_refiner
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
@ -99,7 +105,6 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
return config_alt_diffusion_m18

View File

@ -32,3 +32,9 @@ class WebuiSdModel(LatentDiffusion):
is_sd1: bool
"""True if the model's architecture is SD 1.x"""
is_sd3: bool
"""True if the model's architecture is SD 3"""
latent_channels: int
"""number of layer in latent image representation; will be 16 in SD3 and 4 in other version"""

View File

@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
else:
if model is None:
model = shared.sd_model
with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
return x_sample
@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None):
else:
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
try:
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma)))
except AttributeError: # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
timestep = torch.max(sigma).to(dtype=int)
completed_ratio = (999 - timestep) / 1000
@ -246,7 +246,7 @@ class Sampler:
self.eta_infotext_field = 'Eta'
self.eta_default = 1.0
self.conditioning_key = shared.sd_model.model.conditioning_key
self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')
self.p = None
self.model_wrap_cfg = None

View File

@ -53,8 +53,13 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
@property
def inner_model(self):
if self.model_wrap is None:
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
if denoiser_constructor is not None:
self.model_wrap = denoiser_constructor()
else:
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
return self.model_wrap
@ -120,7 +125,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
if discard_next_to_last_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
return sigmas
return sigmas.cpu()
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
@ -128,7 +133,10 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
sigmas = self.get_sigmas(p, steps)
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
if hasattr(shared.sd_model, 'add_noise_to_latent'):
xi = shared.sd_model.add_noise_to_latent(x, noise, sigma_sched[0])
else:
xi = x + noise * sigma_sched[0]
if opts.img2img_extra_noise > 0:
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise

View File

@ -76,6 +76,7 @@ def kl_optimal(n, sigma_min, sigma_max, device):
sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max)
return sigmas
def simple_scheduler(n, sigma_min, sigma_max, inner_model, device):
sigs = []
ss = len(inner_model.sigmas) / n
@ -85,6 +86,35 @@ def simple_scheduler(n, sigma_min, sigma_max, inner_model, device):
return torch.FloatTensor(sigs).to(device)
def normal_scheduler(n, sigma_min, sigma_max, inner_model, device, sgm=False, floor=False):
start = inner_model.sigma_to_t(torch.tensor(sigma_max))
end = inner_model.sigma_to_t(torch.tensor(sigma_min))
if sgm:
timesteps = torch.linspace(start, end, n + 1)[:-1]
else:
timesteps = torch.linspace(start, end, n)
sigs = []
for x in range(len(timesteps)):
ts = timesteps[x]
sigs.append(inner_model.t_to_sigma(ts))
sigs += [0.0]
return torch.FloatTensor(sigs).to(device)
def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
sigs = []
ss = max(len(inner_model.sigmas) // n, 1)
x = 1
while x < len(inner_model.sigmas):
sigs += [float(inner_model.sigmas[x])]
x += ss
sigs = sigs[::-1]
sigs += [0.0]
return torch.FloatTensor(sigs).to(device)
schedulers = [
Scheduler('automatic', 'Automatic', None),
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
@ -95,6 +125,8 @@ schedulers = [
Scheduler('kl_optimal', 'KL Optimal', kl_optimal),
Scheduler('align_your_steps', 'Align Your Steps', get_align_your_steps_sigmas),
Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True),
Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),
Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),
]
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}

View File

@ -8,9 +8,9 @@ sd_vae_approx_models = {}
class VAEApprox(nn.Module):
def __init__(self):
def __init__(self, latent_channels=4):
super(VAEApprox, self).__init__()
self.conv1 = nn.Conv2d(4, 8, (7, 7))
self.conv1 = nn.Conv2d(latent_channels, 8, (7, 7))
self.conv2 = nn.Conv2d(8, 16, (5, 5))
self.conv3 = nn.Conv2d(16, 32, (3, 3))
self.conv4 = nn.Conv2d(32, 64, (3, 3))
@ -40,7 +40,13 @@ def download_model(model_path, model_url):
def model():
model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
if shared.sd_model.is_sd3:
model_name = "vaeapprox-sd3.pt"
elif shared.sd_model.is_sdxl:
model_name = "vaeapprox-sdxl.pt"
else:
model_name = "model.pt"
loaded_model = sd_vae_approx_models.get(model_name)
if loaded_model is None:
@ -52,7 +58,7 @@ def model():
model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
loaded_model = VAEApprox()
loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)
loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
loaded_model.eval()
loaded_model.to(devices.device, devices.dtype)
@ -64,7 +70,18 @@ def model():
def cheap_approximation(sample):
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
if shared.sd_model.is_sdxl:
if shared.sd_model.is_sd3:
coeffs = [
[-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650],
[ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889],
[ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284],
[ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039],
[ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481],
[ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259],
]
elif shared.sd_model.is_sdxl:
coeffs = [
[ 0.3448, 0.4168, 0.4395],
[-0.1953, -0.0290, 0.0250],

View File

@ -34,9 +34,9 @@ class Block(nn.Module):
return self.fuse(self.conv(x) + self.skip(x))
def decoder():
def decoder(latent_channels=4):
return nn.Sequential(
Clamp(), conv(4, 64), nn.ReLU(),
Clamp(), conv(latent_channels, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
@ -44,13 +44,13 @@ def decoder():
)
def encoder():
def encoder(latent_channels=4):
return nn.Sequential(
conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 4),
conv(64, latent_channels),
)
@ -58,10 +58,14 @@ class TAESDDecoder(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
def __init__(self, decoder_path="taesd_decoder.pth"):
def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.decoder = decoder()
if latent_channels is None:
latent_channels = 16 if "taesd3" in str(decoder_path) else 4
self.decoder = decoder(latent_channels)
self.decoder.load_state_dict(
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
@ -70,10 +74,14 @@ class TAESDEncoder(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
def __init__(self, encoder_path="taesd_encoder.pth"):
def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = encoder()
if latent_channels is None:
latent_channels = 16 if "taesd3" in str(encoder_path) else 4
self.encoder = encoder(latent_channels)
self.encoder.load_state_dict(
torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
@ -87,7 +95,13 @@ def download_model(model_path, model_url):
def decoder_model():
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
if shared.sd_model.is_sd3:
model_name = "taesd3_decoder.pth"
elif shared.sd_model.is_sdxl:
model_name = "taesdxl_decoder.pth"
else:
model_name = "taesd_decoder.pth"
loaded_model = sd_vae_taesd_models.get(model_name)
if loaded_model is None:
@ -106,7 +120,13 @@ def decoder_model():
def encoder_model():
model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
if shared.sd_model.is_sd3:
model_name = "taesd3_encoder.pth"
elif shared.sd_model.is_sdxl:
model_name = "taesdxl_encoder.pth"
else:
model_name = "taesd_encoder.pth"
loaded_model = sd_vae_taesd_models.get(model_name)
if loaded_model is None:

View File

@ -191,6 +191,10 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"),
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
}))
options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
"sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"),
}))
options_templates.update(options_section(('vae', "VAE", "sd"), {
"sd_vae_explanation": OptionHTML("""
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>

View File

@ -194,7 +194,7 @@ class UserMetadataEditor:
def setup_ui(self, gallery):
self.button_replace_preview.click(
fn=self.save_preview,
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
_js=f"function(x, y, z){{return [selected_gallery_index_id('{self.tabname + '_gallery_container'}'), y, z]}}",
inputs=[self.edit_name_input, gallery, self.edit_name_input],
outputs=[self.html_preview, self.html_status]
).then(

View File

@ -18,6 +18,7 @@ omegaconf
open-clip-torch
piexif
protobuf==3.20.0
psutil
pytorch_lightning
requests

View File

@ -18,12 +18,14 @@ numpy==1.26.2
omegaconf==2.2.3
open-clip-torch==2.20.0
piexif==1.1.3
protobuf==3.20.0
psutil==5.9.5
pytorch_lightning==1.9.4
resize-right==0.0.2
safetensors==0.4.2
scikit-image==0.21.0
spandrel==0.1.6
spandrel==0.3.4
spandrel-extra-arches==0.1.1
tomesd==0.1.3
torch
torchdiffeq==0.2.3