From 5b2a60b8e2b7fb1221359047cbe9bc1f6cf0c51d Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 16 Jun 2024 08:04:31 +0300
Subject: [PATCH] initial SD3 support

---
 README.md                           |   2 +-
 configs/sd3-inference.yaml          |   5 +
 extensions-builtin/Lora/networks.py |   4 +-
 modules/models/sd3/mmdit.py         |   3 +-
 modules/models/sd3/sd3_impls.py     |  14 +--
 modules/models/sd3/sd3_model.py     | 166 ++++++++++++++++++++++++++++
 modules/processing.py               |   3 +-
 modules/sd_models.py                |  87 ++++++++++++---
 modules/sd_models_config.py         |   7 +-
 modules/sd_models_types.py          |   6 +
 modules/sd_samplers_common.py       |   4 +-
 modules/sd_samplers_kdiffusion.py   |   9 +-
 modules/sd_vae_approx.py            |  27 ++++-
 modules/sd_vae_taesd.py             |  40 +++++--
 14 files changed, 333 insertions(+), 44 deletions(-)
 create mode 100644 configs/sd3-inference.yaml
 create mode 100644 modules/models/sd3/sd3_model.py

diff --git a/README.md b/README.md
index bc08e7ad1..fc582e15c 100644
--- a/README.md
+++ b/README.md
@@ -150,7 +150,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
 ## Credits
 Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
 
-- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
+- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
 - Spandrel - https://github.com/chaiNNer-org/spandrel implementing
   - GFPGAN - https://github.com/TencentARC/GFPGAN.git
diff --git a/configs/sd3-inference.yaml b/configs/sd3-inference.yaml
new file mode 100644
index 000000000..bccb69d2e
--- /dev/null
+++ b/configs/sd3-inference.yaml
@@ -0,0 +1,5 @@
+model:
+  target: modules.models.sd3.sd3_model.SD3Inferencer
+  params:
+    shift: 3
+    state_dict: null
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 8869d2c82..63e8c9465 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -130,7 +130,9 @@ def assign_network_names_to_compvis_modules(sd_model):
                 network_layer_mapping[network_name] = module
                 module.network_layer_name = network_name
     else:
-        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+        cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
+
+        for name, module in cond_stage_model.named_modules():
             network_name = name.replace(".", "_")
             network_layer_mapping[network_name] = module
             module.network_layer_name = network_name
diff --git a/modules/models/sd3/mmdit.py b/modules/models/sd3/mmdit.py
index 6d8b65bdf..5ec73c054 100644
--- a/modules/models/sd3/mmdit.py
+++ b/modules/models/sd3/mmdit.py
@@ -6,7 +6,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 from einops import rearrange, repeat
-from other_impls import attention, Mlp
+from modules.models.sd3.other_impls import attention, Mlp
+
 
 class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding"""
diff --git a/modules/models/sd3/sd3_impls.py b/modules/models/sd3/sd3_impls.py
index 6e9d0a4db..91dad66d0 100644
--- a/modules/models/sd3/sd3_impls.py
+++ b/modules/models/sd3/sd3_impls.py
@@ -1,7 +1,7 @@
 ### Impls of the SD3 core diffusion model and VAE
 
 import torch, math, einops
-from mmdit import MMDiT
+from modules.models.sd3.mmdit import MMDiT
 from PIL import Image
 
 
@@ -46,16 +46,16 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
 
 class BaseModel(torch.nn.Module):
     """Wrapper around the core MM-DiT model"""
-    def __init__(self, shift=1.0, device=None, dtype=torch.float32, file=None, prefix=""):
+    def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
         super().__init__()
         # Important configuration values can be quickly determined by checking shapes in the source file
         # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
-        patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
-        depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
-        num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
+        patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
+        depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
+        num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
         pos_embed_max_size = round(math.sqrt(num_patches))
-        adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
-        context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
+        adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
+        context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
         context_embedder_config = {
             "target": "torch.nn.Linear",
             "params": {
diff --git a/modules/models/sd3/sd3_model.py b/modules/models/sd3/sd3_model.py
new file mode 100644
index 000000000..8b8285244
--- /dev/null
+++ b/modules/models/sd3/sd3_model.py
@@ -0,0 +1,166 @@
+import contextlib
+import os
+from typing import Mapping
+
+import safetensors
+import torch
+
+import k_diffusion
+from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
+from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
+
+from modules import shared, modelloader, devices
+
+CLIPG_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_g.safetensors"
+CLIPG_CONFIG = {
+    "hidden_act": "gelu",
+    "hidden_size": 1280,
+    "intermediate_size": 5120,
+    "num_attention_heads": 20,
+    "num_hidden_layers": 32,
+}
+
+CLIPL_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_l.safetensors"
+CLIPL_CONFIG = {
+    "hidden_act": "quick_gelu",
+    "hidden_size": 768,
+    "intermediate_size": 3072,
+    "num_attention_heads": 12,
+    "num_hidden_layers": 12,
+}
+
+T5_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/t5xxl_fp16.safetensors"
+T5_CONFIG = {
+    "d_ff": 10240,
+    "d_model": 4096,
+    "num_heads": 64,
+    "num_layers": 24,
+    "vocab_size": 32128,
+}
+
+
+class SafetensorsMapping(Mapping):
+    def __init__(self, file):
+        self.file = file
+
+    def __len__(self):
+        return len(self.file.keys())
+
+    def __iter__(self):
+        for key in self.file.keys():
+            yield key
+
+    def __getitem__(self, key):
+        return self.file.get_tensor(key)
+
+
+class SD3Cond(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.tokenizer = SD3Tokenizer()
+
+        with torch.no_grad():
+            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
+            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
+            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
+
+        self.weights_loaded = False
+
+    def forward(self, prompts: list[str]):
+        res = []
+
+        for prompt in prompts:
+            tokens = self.tokenizer.tokenize_with_weights(prompt)
+            l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
+            g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
+            t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
+            lg_out = torch.cat([l_out, g_out], dim=-1)
+            lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
+            lgt_out = torch.cat([lg_out, t5_out], dim=-2)
+            vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
+
+            res.append({
+                'crossattn': lgt_out[0].to(devices.device),
+                'vector': vector_out[0].to(devices.device),
+            })
+
+        return res
+
+    def load_weights(self):
+        if self.weights_loaded:
+            return
+
+        clip_path = os.path.join(shared.models_path, "CLIP")
+
+        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
+        with safetensors.safe_open(clip_g_file, framework="pt") as file:
+            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
+
+        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
+        with safetensors.safe_open(clip_l_file, framework="pt") as file:
+            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+
+        t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
+        with safetensors.safe_open(t5_file, framework="pt") as file:
+            self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+
+        self.weights_loaded = True
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        return torch.tensor([[0]], device=devices.device)  # XXX
+
+
+class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
+    def __init__(self, inner_model, sigmas):
+        super().__init__(sigmas, quantize=shared.opts.enable_quantization)
+        self.inner_model = inner_model
+
+    def forward(self, input, sigma, **kwargs):
+        return self.inner_model.apply_model(input, sigma, **kwargs)
+
+
+class SD3Inferencer(torch.nn.Module):
+    def __init__(self, state_dict, shift=3, use_ema=False):
+        super().__init__()
+
+        self.shift = shift
+
+        with torch.no_grad():
+            self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
+            self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
+            self.first_stage_model.dtype = self.model.diffusion_model.dtype
+
+        self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
+
+        self.cond_stage_model = SD3Cond()
+        self.cond_stage_key = 'txt'
+
+        self.parameterization = "eps"
+        self.model.conditioning_key = "crossattn"
+
+        self.latent_format = SD3LatentFormat()
+        self.latent_channels = 16
+
+    def after_load_weights(self):
+        self.cond_stage_model.load_weights()
+
+    def ema_scope(self):
+        return contextlib.nullcontext()
+
+    def get_learned_conditioning(self, batch: list[str]):
+        return self.cond_stage_model(batch)
+
+    def apply_model(self, x, t, cond):
+        return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
+
+    def decode_first_stage(self, latent):
+        latent = self.latent_format.process_out(latent)
+        return self.first_stage_model.decode(latent)
+
+    def encode_first_stage(self, image):
+        latent = self.first_stage_model.encode(image)
+        return self.latent_format.process_in(latent)
+
+    def create_denoiser(self):
+        return SD3Denoiser(self, self.model.model_sampling.sigmas)
diff --git a/modules/processing.py b/modules/processing.py
index 79a3f0a72..d32a1811e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -942,7 +942,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
 
-            p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
+            p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
 
             if p.scripts is not None:
                 p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index af35187cd..21a98c1de 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,7 +1,9 @@
 import collections
+import importlib
 import os
 import sys
 import threading
+import enum
 
 import torch
 import re
@@ -10,8 +12,6 @@ from omegaconf import OmegaConf, ListConfig
 from urllib import request
 import ldm.modules.midas as midas
 
-from ldm.util import instantiate_from_config
-
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
 from modules.timer import Timer
 from modules.shared import opts
@@ -27,6 +27,14 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
 checkpoints_loaded = collections.OrderedDict()
 
 
+class ModelType(enum.Enum):
+    SD1 = 1
+    SD2 = 2
+    SDXL = 3
+    SSD = 4
+    SD3 = 5
+
+
 def replace_key(d, key, new_key, value):
     keys = list(d.keys())
 
@@ -368,6 +376,36 @@ def check_fp8(model):
     return enable_fp8
 
 
+def set_model_type(model, state_dict):
+    model.is_sd1 = False
+    model.is_sd2 = False
+    model.is_sdxl = False
+    model.is_ssd = False
+    model.is_sd3 = False
+
+    if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
+        model.is_sd3 = True
+        model.model_type = ModelType.SD3
+    elif hasattr(model, 'conditioner'):
+        model.is_sdxl = True
+
+        if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
+            model.is_ssd = True
+            model.model_type = ModelType.SSD
+        else:
+            model.model_type = ModelType.SDXL
+    elif hasattr(model.cond_stage_model, 'model'):
+        model.is_sd2 = True
+        model.model_type = ModelType.SD2
+    else:
+        model.is_sd1 = True
+        model.model_type = ModelType.SD1
+
+
+def set_model_fields(model):
+    if not hasattr(model, 'latent_channels'):
+        model.latent_channels = 4
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
@@ -382,10 +420,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
-    model.is_sdxl = hasattr(model, 'conditioner')
-    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
-    model.is_sd1 = not model.is_sdxl and not model.is_sd2
-    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+    set_model_type(model, state_dict)
+    set_model_fields(model)
+
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
 
@@ -552,8 +589,7 @@ def patch_given_betas():
     original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
 
 
-def repair_config(sd_config):
-
+def repair_config(sd_config, state_dict=None):
     if not hasattr(sd_config.model.params, "use_ema"):
         sd_config.model.params.use_ema = False
 
@@ -563,8 +599,9 @@ def repair_config(sd_config):
     elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
         sd_config.model.params.unet_config.params.use_fp16 = True
 
-    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
-        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
+    if hasattr(sd_config.model.params, 'first_stage_config'):
+        if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
+            sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
 
     # For UnCLIP-L, override the hardcoded karlo directory
     if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
@@ -580,6 +617,7 @@
     sd_config.model.params.unet_config.params.use_checkpoint = False
 
 
+
 def rescale_zero_terminal_snr_abar(alphas_cumprod):
     alphas_bar_sqrt = alphas_cumprod.sqrt()
 
@@ -715,6 +753,25 @@ def send_model_to_trash(m):
     devices.torch_gc()
 
 
+def instantiate_from_config(config, state_dict=None):
+    constructor = get_obj_from_str(config["target"])
+
+    params = {**config.get("params", {})}
+
+    if state_dict and "state_dict" in params and params["state_dict"] is None:
+        params["state_dict"] = state_dict
+
+    return constructor(**params)
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -739,7 +796,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("find config")
 
     sd_config = OmegaConf.load(checkpoint_config)
-    repair_config(sd_config)
+    repair_config(sd_config, state_dict)
 
     timer.record("load config")
 
@@ -749,7 +806,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
            with sd_disable_initialization.InitializeOnMeta():
-                sd_model = instantiate_from_config(sd_config.model)
+                sd_model = instantiate_from_config(sd_config.model, state_dict)
 
    except Exception as e:
         errors.display(e, "creating model quickly", full_traceback=True)
@@ -758,7 +815,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
 
         with sd_disable_initialization.InitializeOnMeta():
-            sd_model = instantiate_from_config(sd_config.model)
+            sd_model = instantiate_from_config(sd_config.model, state_dict)
 
     sd_model.used_config = checkpoint_config
 
@@ -775,6 +832,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+
+        if hasattr(sd_model, "after_load_weights"):
+            sd_model.after_load_weights()
+
         timer.record("load weights from state dict")
 
     send_model_to_device(sd_model)
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 9cec4f13d..7cfeca67f 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -23,6 +23,8 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml"
 config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
 config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
+config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
+
 
 def is_using_v_parameterization_for_sd2(state_dict):
     """
@@ -71,11 +73,15 @@ def guess_model_config_from_state_dict(sd, filename):
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
     sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
 
+    if "model.diffusion_model.x_embedder.proj.weight" in sd:
+        return config_sd3
+
     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
         if diffusion_model_input.shape[1] == 9:
             return config_sdxl_inpainting
         else:
             return config_sdxl
+
     if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
         return config_sdxl_refiner
     elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
@@ -99,7 +105,6 @@
 
         if diffusion_model_input.shape[1] == 8:
             return config_instruct_pix2pix
-
     if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
         if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
             return config_alt_diffusion_m18
diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py
index f911fbb68..2fce2777b 100644
--- a/modules/sd_models_types.py
+++ b/modules/sd_models_types.py
@@ -32,3 +32,9 @@ class WebuiSdModel(LatentDiffusion):
 
     is_sd1: bool
     """True if the model's architecture is SD 1.x"""
+
+    is_sd3: bool
+    """True if the model's architecture is SD 3"""
+
+    latent_channels: int
+    """number of channels in the latent image representation; will be 16 for SD3 and 4 for other models"""
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index bda578cc5..b584b68a9 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
     else:
         if model is None:
             model = shared.sd_model
-        with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
+        with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
             x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
 
     return x_sample
@@ -246,7 +246,7 @@ class Sampler:
         self.eta_infotext_field = 'Eta'
         self.eta_default = 1.0
 
-        self.conditioning_key = shared.sd_model.model.conditioning_key
+        self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')
 
         self.p = None
         self.model_wrap_cfg = None
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 64e14e0c2..cede0760a 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -53,8 +53,13 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
     @property
     def inner_model(self):
         if self.model_wrap is None:
-            denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-            self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
+            denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
+
+            if denoiser_constructor is not None:
+                self.model_wrap = denoiser_constructor()
+            else:
+                denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+                self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
 
         return self.model_wrap
 
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 3965e223e..c5dda7431 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -8,9 +8,9 @@ sd_vae_approx_models = {}
 
 
 class VAEApprox(nn.Module):
-    def __init__(self):
+    def __init__(self, latent_channels=4):
         super(VAEApprox, self).__init__()
-        self.conv1 = nn.Conv2d(4, 8, (7, 7))
+        self.conv1 = nn.Conv2d(latent_channels, 8, (7, 7))
         self.conv2 = nn.Conv2d(8, 16, (5, 5))
         self.conv3 = nn.Conv2d(16, 32, (3, 3))
         self.conv4 = nn.Conv2d(32, 64, (3, 3))
@@ -40,7 +40,13 @@ def download_model(model_path, model_url):
 
 
 def model():
-    model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+    if shared.sd_model.is_sd3:
+        model_name = "vaeapprox-sd3.pt"
+    elif shared.sd_model.is_sdxl:
+        model_name = "vaeapprox-sdxl.pt"
+    else:
+        model_name = "model.pt"
+
     loaded_model = sd_vae_approx_models.get(model_name)
 
     if loaded_model is None:
@@ -52,7 +58,7 @@ def model():
         model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
         download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
 
-        loaded_model = VAEApprox()
+        loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)
         loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
         loaded_model.eval()
         loaded_model.to(devices.device, devices.dtype)
@@ -64,7 +70,18 @@ def model():
 
 def cheap_approximation(sample):
     # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
-    if shared.sd_model.is_sdxl:
+    if shared.sd_model.is_sd3:
+        coeffs = [
+            [-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650],
+            [ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889],
+            [ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284],
+            [ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047],
+            [-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039],
+            [ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481],
+            [ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
+            [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259],
+        ]
+    elif shared.sd_model.is_sdxl:
         coeffs = [
             [ 0.3448, 0.4168, 0.4395],
             [-0.1953, -0.0290, 0.0250],
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
index 808eb3624..d06253d2a 100644
--- a/modules/sd_vae_taesd.py
+++ b/modules/sd_vae_taesd.py
@@ -34,9 +34,9 @@ class Block(nn.Module):
         return self.fuse(self.conv(x) + self.skip(x))
 
 
-def decoder():
+def decoder(latent_channels=4):
     return nn.Sequential(
-        Clamp(), conv(4, 64), nn.ReLU(),
+        Clamp(), conv(latent_channels, 64), nn.ReLU(),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
@@ -44,13 +44,13 @@
     )
 
 
-def encoder():
+def encoder(latent_channels=4):
     return nn.Sequential(
         conv(3, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
-        conv(64, 4),
+        conv(64, latent_channels),
     )
 
 
@@ -58,10 +58,14 @@ class TAESDDecoder(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5
 
-    def __init__(self, decoder_path="taesd_decoder.pth"):
+    def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.decoder = decoder()
+
+        if latent_channels is None:
+            latent_channels = 16 if "taesd3" in str(decoder_path) else 4
+
+        self.decoder = decoder(latent_channels)
         self.decoder.load_state_dict(
             torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
 
@@ -70,10 +74,14 @@ class TAESDEncoder(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5
 
-    def __init__(self, encoder_path="taesd_encoder.pth"):
+    def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.encoder = encoder()
+
+        if latent_channels is None:
+            latent_channels = 16 if "taesd3" in str(encoder_path) else 4
+
+        self.encoder = encoder(latent_channels)
         self.encoder.load_state_dict(
             torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
 
@@ -87,7 +95,13 @@ def download_model(model_path, model_url):
 
 
 def decoder_model():
-    model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+    if shared.sd_model.is_sd3:
+        model_name = "taesd3_decoder.pth"
+    elif shared.sd_model.is_sdxl:
+        model_name = "taesdxl_decoder.pth"
+    else:
+        model_name = "taesd_decoder.pth"
+
     loaded_model = sd_vae_taesd_models.get(model_name)
 
     if loaded_model is None:
@@ -106,7 +120,13 @@ def encoder_model():
-    model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
+    if shared.sd_model.is_sd3:
+        model_name = "taesd3_encoder.pth"
+    elif shared.sd_model.is_sdxl:
+        model_name = "taesdxl_encoder.pth"
+    else:
+        model_name = "taesd_encoder.pth"
+
     loaded_model = sd_vae_taesd_models.get(model_name)
 
     if loaded_model is None: