initial SD3 support

This commit is contained in:
AUTOMATIC1111 2024-06-16 08:04:31 +03:00
parent a7116aa9a1
commit 5b2a60b8e2
14 changed files with 333 additions and 44 deletions

View File

@ -150,7 +150,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
## Credits ## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file. Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers - Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- Spandrel - https://github.com/chaiNNer-org/spandrel implementing - Spandrel - https://github.com/chaiNNer-org/spandrel implementing
- GFPGAN - https://github.com/TencentARC/GFPGAN.git - GFPGAN - https://github.com/TencentARC/GFPGAN.git

View File

@ -0,0 +1,5 @@
model:
target: modules.models.sd3.sd3_model.SD3Inferencer
params:
shift: 3
state_dict: null

View File

@ -130,7 +130,9 @@ def assign_network_names_to_compvis_modules(sd_model):
network_layer_mapping[network_name] = module network_layer_mapping[network_name] = module
module.network_layer_name = network_name module.network_layer_name = network_name
else: else:
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
for name, module in cond_stage_model.named_modules():
network_name = name.replace(".", "_") network_name = name.replace(".", "_")
network_layer_mapping[network_name] = module network_layer_mapping[network_name] = module
module.network_layer_name = network_name module.network_layer_name = network_name

View File

@ -6,7 +6,8 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from einops import rearrange, repeat from einops import rearrange, repeat
from other_impls import attention, Mlp from modules.models.sd3.other_impls import attention, Mlp
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding""" """ 2D Image to Patch Embedding"""

View File

@ -1,7 +1,7 @@
### Impls of the SD3 core diffusion model and VAE ### Impls of the SD3 core diffusion model and VAE
import torch, math, einops import torch, math, einops
from mmdit import MMDiT from modules.models.sd3.mmdit import MMDiT
from PIL import Image from PIL import Image
@ -46,16 +46,16 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
class BaseModel(torch.nn.Module): class BaseModel(torch.nn.Module):
"""Wrapper around the core MM-DiT model""" """Wrapper around the core MM-DiT model"""
def __init__(self, shift=1.0, device=None, dtype=torch.float32, file=None, prefix=""): def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
super().__init__() super().__init__()
# Important configuration values can be quickly determined by checking shapes in the source file # Important configuration values can be quickly determined by checking shapes in the source file
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2] patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64 depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1] num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches)) pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1] adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
context_embedder_config = { context_embedder_config = {
"target": "torch.nn.Linear", "target": "torch.nn.Linear",
"params": { "params": {

View File

@ -0,0 +1,166 @@
import contextlib
import os
from typing import Mapping
import safetensors
import torch
import k_diffusion
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
from modules import shared, modelloader, devices
CLIPG_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_g.safetensors"
CLIPG_CONFIG = {
"hidden_act": "gelu",
"hidden_size": 1280,
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
}
CLIPL_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_l.safetensors"
CLIPL_CONFIG = {
"hidden_act": "quick_gelu",
"hidden_size": 768,
"intermediate_size": 3072,
"num_attention_heads": 12,
"num_hidden_layers": 12,
}
T5_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/t5xxl_fp16.safetensors"
T5_CONFIG = {
"d_ff": 10240,
"d_model": 4096,
"num_heads": 64,
"num_layers": 24,
"vocab_size": 32128,
}
class SafetensorsMapping(Mapping):
def __init__(self, file):
self.file = file
def __len__(self):
return len(self.file.keys())
def __iter__(self):
for key in self.file.keys():
yield key
def __getitem__(self, key):
return self.file.get_tensor(key)
class SD3Cond(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer = SD3Tokenizer()
with torch.no_grad():
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
self.weights_loaded = False
def forward(self, prompts: list[str]):
res = []
for prompt in prompts:
tokens = self.tokenizer.tokenize_with_weights(prompt)
l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
res.append({
'crossattn': lgt_out[0].to(devices.device),
'vector': vector_out[0].to(devices.device),
})
return res
def load_weights(self):
if self.weights_loaded:
return
clip_path = os.path.join(shared.models_path, "CLIP")
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
with safetensors.safe_open(clip_g_file, framework="pt") as file:
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
with safetensors.safe_open(clip_l_file, framework="pt") as file:
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
with safetensors.safe_open(t5_file, framework="pt") as file:
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
self.weights_loaded = True
def encode_embedding_init_text(self, init_text, nvpt):
return torch.tensor([[0]], device=devices.device) # XXX
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
def __init__(self, inner_model, sigmas):
super().__init__(sigmas, quantize=shared.opts.enable_quantization)
self.inner_model = inner_model
def forward(self, input, sigma, **kwargs):
return self.inner_model.apply_model(input, sigma, **kwargs)
class SD3Inferencer(torch.nn.Module):
def __init__(self, state_dict, shift=3, use_ema=False):
super().__init__()
self.shift = shift
with torch.no_grad():
self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
self.first_stage_model.dtype = self.model.diffusion_model.dtype
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
self.cond_stage_model = SD3Cond()
self.cond_stage_key = 'txt'
self.parameterization = "eps"
self.model.conditioning_key = "crossattn"
self.latent_format = SD3LatentFormat()
self.latent_channels = 16
def after_load_weights(self):
self.cond_stage_model.load_weights()
def ema_scope(self):
return contextlib.nullcontext()
def get_learned_conditioning(self, batch: list[str]):
return self.cond_stage_model(batch)
def apply_model(self, x, t, cond):
return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
def decode_first_stage(self, latent):
latent = self.latent_format.process_out(latent)
return self.first_stage_model.decode(latent)
def encode_first_stage(self, image):
latent = self.first_stage_model.encode(image)
return self.latent_format.process_in(latent)
def create_denoiser(self):
return SD3Denoiser(self, self.model.model_sampling.sigmas)

View File

@ -942,7 +942,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
if p.scripts is not None: if p.scripts is not None:
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

View File

@ -1,7 +1,9 @@
import collections import collections
import importlib
import os import os
import sys import sys
import threading import threading
import enum
import torch import torch
import re import re
@ -10,8 +12,6 @@ from omegaconf import OmegaConf, ListConfig
from urllib import request from urllib import request
import ldm.modules.midas as midas import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer from modules.timer import Timer
from modules.shared import opts from modules.shared import opts
@ -27,6 +27,14 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
checkpoints_loaded = collections.OrderedDict() checkpoints_loaded = collections.OrderedDict()
class ModelType(enum.Enum):
SD1 = 1
SD2 = 2
SDXL = 3
SSD = 4
SD3 = 5
def replace_key(d, key, new_key, value): def replace_key(d, key, new_key, value):
keys = list(d.keys()) keys = list(d.keys())
@ -368,6 +376,36 @@ def check_fp8(model):
return enable_fp8 return enable_fp8
def set_model_type(model, state_dict):
model.is_sd1 = False
model.is_sd2 = False
model.is_sdxl = False
model.is_ssd = False
model.is_ssd3 = False
if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
model.is_sd3 = True
model.model_type = ModelType.SD3
elif hasattr(model, 'conditioner'):
model.is_sdxl = True
if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
model.is_ssd = True
model.model_type = ModelType.SSD
else:
model.model_type = ModelType.SDXL
elif hasattr(model.cond_stage_model, 'model'):
model.is_sd2 = True
model.model_type = ModelType.SD2
else:
model.is_sd1 = True
model.model_type = ModelType.SD1
def set_model_fields(model):
if not hasattr(model, 'latent_channels'):
model.latent_channels = 4
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash() sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash") timer.record("calculate hash")
@ -382,10 +420,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if state_dict is None: if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer) state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
model.is_sdxl = hasattr(model, 'conditioner') set_model_type(model, state_dict)
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') set_model_fields(model)
model.is_sd1 = not model.is_sdxl and not model.is_sd2
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
if model.is_sdxl: if model.is_sdxl:
sd_models_xl.extend_sdxl(model) sd_models_xl.extend_sdxl(model)
@ -552,8 +589,7 @@ def patch_given_betas():
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule) original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
def repair_config(sd_config): def repair_config(sd_config, state_dict=None):
if not hasattr(sd_config.model.params, "use_ema"): if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False sd_config.model.params.use_ema = False
@ -563,6 +599,7 @@ def repair_config(sd_config):
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half": elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
sd_config.model.params.unet_config.params.use_fp16 = True sd_config.model.params.unet_config.params.use_fp16 = True
if hasattr(sd_config.model.params, 'first_stage_config'):
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
@ -580,6 +617,7 @@ def repair_config(sd_config):
sd_config.model.params.unet_config.params.use_checkpoint = False sd_config.model.params.unet_config.params.use_checkpoint = False
def rescale_zero_terminal_snr_abar(alphas_cumprod): def rescale_zero_terminal_snr_abar(alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt() alphas_bar_sqrt = alphas_cumprod.sqrt()
@ -715,6 +753,25 @@ def send_model_to_trash(m):
devices.torch_gc() devices.torch_gc()
def instantiate_from_config(config, state_dict=None):
constructor = get_obj_from_str(config["target"])
params = {**config.get("params", {})}
if state_dict and "state_dict" in params and params["state_dict"] is None:
params["state_dict"] = state_dict
return constructor(**params)
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def load_model(checkpoint_info=None, already_loaded_state_dict=None): def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import sd_hijack from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
@ -739,7 +796,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("find config") timer.record("find config")
sd_config = OmegaConf.load(checkpoint_config) sd_config = OmegaConf.load(checkpoint_config)
repair_config(sd_config) repair_config(sd_config, state_dict)
timer.record("load config") timer.record("load config")
@ -749,7 +806,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
try: try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
with sd_disable_initialization.InitializeOnMeta(): with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model, state_dict)
except Exception as e: except Exception as e:
errors.display(e, "creating model quickly", full_traceback=True) errors.display(e, "creating model quickly", full_traceback=True)
@ -758,7 +815,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
with sd_disable_initialization.InitializeOnMeta(): with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model, state_dict)
sd_model.used_config = checkpoint_config sd_model.used_config = checkpoint_config
@ -775,6 +832,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion): with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer) load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if hasattr(sd_model, "after_load_weights"):
sd_model.after_load_weights()
timer.record("load weights from state dict") timer.record("load weights from state dict")
send_model_to_device(sd_model) send_model_to_device(sd_model)

View File

@ -23,6 +23,8 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml"
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
def is_using_v_parameterization_for_sd2(state_dict): def is_using_v_parameterization_for_sd2(state_dict):
""" """
@ -71,11 +73,15 @@ def guess_model_config_from_state_dict(sd, filename):
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
if "model.diffusion_model.x_embedder.proj.weight" in sd:
return config_sd3
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
if diffusion_model_input.shape[1] == 9: if diffusion_model_input.shape[1] == 9:
return config_sdxl_inpainting return config_sdxl_inpainting
else: else:
return config_sdxl return config_sdxl
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
return config_sdxl_refiner return config_sdxl_refiner
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
@ -99,7 +105,6 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 8: if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix return config_instruct_pix2pix
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
return config_alt_diffusion_m18 return config_alt_diffusion_m18

View File

@ -32,3 +32,9 @@ class WebuiSdModel(LatentDiffusion):
is_sd1: bool is_sd1: bool
"""True if the model's architecture is SD 1.x""" """True if the model's architecture is SD 1.x"""
is_sd3: bool
"""True if the model's architecture is SD 3"""
latent_channels: int
"""number of layer in latent image representation; will be 16 in SD3 and 4 in other version"""

View File

@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
else: else:
if model is None: if model is None:
model = shared.sd_model model = shared.sd_model
with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32 with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype)) x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
return x_sample return x_sample
@ -246,7 +246,7 @@ class Sampler:
self.eta_infotext_field = 'Eta' self.eta_infotext_field = 'Eta'
self.eta_default = 1.0 self.eta_default = 1.0
self.conditioning_key = shared.sd_model.model.conditioning_key self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')
self.p = None self.p = None
self.model_wrap_cfg = None self.model_wrap_cfg = None

View File

@ -53,6 +53,11 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
@property @property
def inner_model(self): def inner_model(self):
if self.model_wrap is None: if self.model_wrap is None:
denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
if denoiser_constructor is not None:
self.model_wrap = denoiser_constructor()
else:
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)

View File

@ -8,9 +8,9 @@ sd_vae_approx_models = {}
class VAEApprox(nn.Module): class VAEApprox(nn.Module):
def __init__(self): def __init__(self, latent_channels=4):
super(VAEApprox, self).__init__() super(VAEApprox, self).__init__()
self.conv1 = nn.Conv2d(4, 8, (7, 7)) self.conv1 = nn.Conv2d(latent_channels, 8, (7, 7))
self.conv2 = nn.Conv2d(8, 16, (5, 5)) self.conv2 = nn.Conv2d(8, 16, (5, 5))
self.conv3 = nn.Conv2d(16, 32, (3, 3)) self.conv3 = nn.Conv2d(16, 32, (3, 3))
self.conv4 = nn.Conv2d(32, 64, (3, 3)) self.conv4 = nn.Conv2d(32, 64, (3, 3))
@ -40,7 +40,13 @@ def download_model(model_path, model_url):
def model(): def model():
model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt" if shared.sd_model.is_sd3:
model_name = "vaeapprox-sd3.pt"
elif shared.sd_model.is_sdxl:
model_name = "vaeapprox-sdxl.pt"
else:
model_name = "model.pt"
loaded_model = sd_vae_approx_models.get(model_name) loaded_model = sd_vae_approx_models.get(model_name)
if loaded_model is None: if loaded_model is None:
@ -52,7 +58,7 @@ def model():
model_path = os.path.join(paths.models_path, "VAE-approx", model_name) model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name) download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
loaded_model = VAEApprox() loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)
loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
loaded_model.eval() loaded_model.eval()
loaded_model.to(devices.device, devices.dtype) loaded_model.to(devices.device, devices.dtype)
@ -64,7 +70,18 @@ def model():
def cheap_approximation(sample): def cheap_approximation(sample):
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
if shared.sd_model.is_sdxl: if shared.sd_model.is_sd3:
coeffs = [
[-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650],
[ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889],
[ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284],
[ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047],
[-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039],
[ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481],
[ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
[-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259],
]
elif shared.sd_model.is_sdxl:
coeffs = [ coeffs = [
[ 0.3448, 0.4168, 0.4395], [ 0.3448, 0.4168, 0.4395],
[-0.1953, -0.0290, 0.0250], [-0.1953, -0.0290, 0.0250],

View File

@ -34,9 +34,9 @@ class Block(nn.Module):
return self.fuse(self.conv(x) + self.skip(x)) return self.fuse(self.conv(x) + self.skip(x))
def decoder(): def decoder(latent_channels=4):
return nn.Sequential( return nn.Sequential(
Clamp(), conv(4, 64), nn.ReLU(), Clamp(), conv(latent_channels, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
@ -44,13 +44,13 @@ def decoder():
) )
def encoder(): def encoder(latent_channels=4):
return nn.Sequential( return nn.Sequential(
conv(3, 64), Block(64, 64), conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 4), conv(64, latent_channels),
) )
@ -58,10 +58,14 @@ class TAESDDecoder(nn.Module):
latent_magnitude = 3 latent_magnitude = 3
latent_shift = 0.5 latent_shift = 0.5
def __init__(self, decoder_path="taesd_decoder.pth"): def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints.""" """Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__() super().__init__()
self.decoder = decoder()
if latent_channels is None:
latent_channels = 16 if "taesd3" in str(decoder_path) else 4
self.decoder = decoder(latent_channels)
self.decoder.load_state_dict( self.decoder.load_state_dict(
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
@ -70,10 +74,14 @@ class TAESDEncoder(nn.Module):
latent_magnitude = 3 latent_magnitude = 3
latent_shift = 0.5 latent_shift = 0.5
def __init__(self, encoder_path="taesd_encoder.pth"): def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints.""" """Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__() super().__init__()
self.encoder = encoder()
if latent_channels is None:
latent_channels = 16 if "taesd3" in str(encoder_path) else 4
self.encoder = encoder(latent_channels)
self.encoder.load_state_dict( self.encoder.load_state_dict(
torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
@ -87,7 +95,13 @@ def download_model(model_path, model_url):
def decoder_model(): def decoder_model():
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" if shared.sd_model.is_sd3:
model_name = "taesd3_decoder.pth"
elif shared.sd_model.is_sdxl:
model_name = "taesdxl_decoder.pth"
else:
model_name = "taesd_decoder.pth"
loaded_model = sd_vae_taesd_models.get(model_name) loaded_model = sd_vae_taesd_models.get(model_name)
if loaded_model is None: if loaded_model is None:
@ -106,7 +120,13 @@ def decoder_model():
def encoder_model(): def encoder_model():
model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth" if shared.sd_model.is_sd3:
model_name = "taesd3_encoder.pth"
elif shared.sd_model.is_sdxl:
model_name = "taesdxl_encoder.pth"
else:
model_name = "taesd_encoder.pth"
loaded_model = sd_vae_taesd_models.get(model_name) loaded_model = sd_vae_taesd_models.get(model_name)
if loaded_model is None: if loaded_model is None: