From 7e5cdaab4b386621a186999e7348f6d0af7317a7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 15 Jul 2024 08:31:55 +0300
Subject: [PATCH] SD3 lora support

---
 extensions-builtin/Lora/network.py      |  6 +-
 extensions-builtin/Lora/network_lora.py | 10 ++-
 extensions-builtin/Lora/networks.py     | 96 +++++++++++++++++++------
 modules/models/sd3/mmdit.py             |  5 +-
 modules/models/sd3/sd3_impls.py         |  1 +
 modules/models/sd3/sd3_model.py         | 12 ++++
 6 files changed, 106 insertions(+), 24 deletions(-)

diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index 3c39c49d7..98ff367fd 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from modules import sd_models, cache, errors, hashes, shared
+import modules.models.sd3.mmdit
 
 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
 
@@ -114,7 +115,10 @@ class NetworkModule:
         self.sd_key = weights.sd_key
         self.sd_module = weights.sd_module
 
-        if hasattr(self.sd_module, 'weight'):
+        if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
+            s = self.sd_module.weight.shape
+            self.shape = (s[0] // 3, s[1])
+        elif hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
         elif isinstance(self.sd_module, nn.MultiheadAttention):
             # For now, only self-attn use Pytorch's MHA
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index 4cc402951..a7a088949 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -1,6 +1,7 @@
 import torch
 
 import lyco_helpers
+import modules.models.sd3.mmdit
 import network
 from modules import devices
 
@@ -10,6 +11,13 @@ class ModuleTypeLora(network.ModuleType):
         if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
             return NetworkModuleLora(net, weights)
 
+        if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]):
+            w = weights.w.copy()
+            weights.w.clear()
+            weights.w.update({"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]})
+
+            return NetworkModuleLora(net, weights)
+
         return None
 
 
@@ -29,7 +37,7 @@ class NetworkModuleLora(network.NetworkModule):
         if weight is None and none_ok:
             return None
 
-        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
+        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
 
         if is_linear:
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 9ed8fa435..4ad98714b 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -20,6 +20,7 @@ from typing import Union
 
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
 import modules.textual_inversion.textual_inversion as textual_inversion
+import modules.models.sd3.mmdit
 
 from lora_logger import logger
 
@@ -166,12 +167,26 @@ def load_network(name, network_on_disk):
 
     keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
 
+    if hasattr(shared.sd_model, 'diffusers_weight_map'):
+        diffusers_weight_map = shared.sd_model.diffusers_weight_map
+    elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):
+        diffusers_weight_map = {}
+        for k, v in shared.sd_model.diffusers_weight_mapping():
+            diffusers_weight_map[k] = v
+        shared.sd_model.diffusers_weight_map = diffusers_weight_map
+    else:
+        diffusers_weight_map = None
+
     matched_networks = {}
     bundle_embeddings = {}
 
     for key_network, weight in sd.items():
-        key_network_without_network_parts, _, network_part = key_network.partition(".")
+
+        if diffusers_weight_map:
+            key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
+            network_part = network_name + '.' + network_weight
+        else:
+            key_network_without_network_parts, _, network_part = key_network.partition(".")
 
         if key_network_without_network_parts == "bundle_emb":
             emb_name, vec_name = network_part.split(".", 1)
@@ -183,7 +198,11 @@ def load_network(name, network_on_disk):
                 emb_dict[vec_name] = weight
             bundle_embeddings[emb_name] = emb_dict
 
-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+        if diffusers_weight_map:
+            key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)
+        else:
+            key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)
 
         if sd_module is None:
@@ -347,6 +366,28 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
     purge_networks_from_memory()
 
 
+def allowed_layer_without_weight(layer):
+    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
+        return True
+
+    return False
+
+
+def store_weights_backup(weight):
+    if weight is None:
+        return None
+
+    return weight.to(devices.cpu, copy=True)
+
+
+def restore_weights_backup(obj, field, weight):
+    if weight is None:
+        setattr(obj, field, None)
+        return
+
+    getattr(obj, field).copy_(weight)
+
+
 def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     weights_backup = getattr(self, "network_weights_backup", None)
     bias_backup = getattr(self, "network_bias_backup", None)
@@ -356,21 +397,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
 
     if weights_backup is not None:
         if isinstance(self, torch.nn.MultiheadAttention):
-            self.in_proj_weight.copy_(weights_backup[0])
-            self.out_proj.weight.copy_(weights_backup[1])
+            restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
+            restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
         else:
-            self.weight.copy_(weights_backup)
+            restore_weights_backup(self, 'weight', weights_backup)
 
-    if bias_backup is not None:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias.copy_(bias_backup)
-        else:
-            self.bias.copy_(bias_backup)
+    if isinstance(self, torch.nn.MultiheadAttention):
+        restore_weights_backup(self.out_proj, 'bias', bias_backup)
     else:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias = None
-        else:
-            self.bias = None
+        restore_weights_backup(self, 'bias', bias_backup)
 
 
 def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
@@ -389,22 +424,22 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
 
     weights_backup = getattr(self, "network_weights_backup", None)
     if weights_backup is None and wanted_names != ():
-        if current_names != ():
-            raise RuntimeError("no backup weights found and current weights are not unchanged")
+        if current_names != () and not allowed_layer_without_weight(self):
+            raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
 
         if isinstance(self, torch.nn.MultiheadAttention):
-            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
         else:
-            weights_backup = self.weight.to(devices.cpu, copy=True)
+            weights_backup = store_weights_backup(self.weight)
 
         self.network_weights_backup = weights_backup
 
     bias_backup = getattr(self, "network_bias_backup", None)
     if bias_backup is None and wanted_names != ():
         if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
-            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.out_proj.bias)
         elif getattr(self, 'bias', None) is not None:
-            bias_backup = self.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.bias)
         else:
             bias_backup = None
 
@@ -412,6 +447,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
         # Only report if bias is not None and current bias are not unchanged.
         if bias_backup is not None and current_names != ():
             raise RuntimeError("no backup bias found and current bias are not unchanged")
+
         self.network_bias_backup = bias_backup
 
     if current_names != wanted_names:
@@ -419,7 +455,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
 
         for net in loaded_networks:
             module = net.modules.get(network_layer_name, None)
-            if module is not None and hasattr(self, 'weight'):
+            if module is not None and hasattr(self, 'weight') and not isinstance(self, modules.models.sd3.mmdit.QkvLinear):
                 try:
                     with torch.no_grad():
                         if getattr(self, 'fp16_weight', None) is None:
@@ -479,6 +515,24 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
 
                 continue
 
+        if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
+            try:
+                with torch.no_grad():
+                    # Send each chunk of the fused qkv weight into its own LoRA module as the "real" orig_weight
+                    qw, kw, vw = self.weight.chunk(3, 0)
+                    updown_q, _ = module_q.calc_updown(qw)
+                    updown_k, _ = module_k.calc_updown(kw)
+                    updown_v, _ = module_v.calc_updown(vw)
+                    del qw, kw, vw
+                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+                    self.weight += updown_qkv
+
+            except RuntimeError as e:
+                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+            continue
+
         if module is None:
             continue
 
diff --git a/modules/models/sd3/mmdit.py b/modules/models/sd3/mmdit.py
index 4d2b85551..8ddf49a4e 100644
--- a/modules/models/sd3/mmdit.py
+++ b/modules/models/sd3/mmdit.py
@@ -175,6 +175,9 @@ class VectorEmbedder(nn.Module):
 #################################################################################
 
 
+class QkvLinear(torch.nn.Linear):
+    pass
+
 def split_qkv(qkv, head_dim):
     qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
     return qkv[0], qkv[1], qkv[2]
@@ -202,7 +205,7 @@ class SelfAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
 
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+        self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         if not pre_only:
             self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
         assert attn_mode in self.ATTENTION_MODES
diff --git a/modules/models/sd3/sd3_impls.py b/modules/models/sd3/sd3_impls.py
index e2f6cad5b..59f11b2cb 100644
--- a/modules/models/sd3/sd3_impls.py
+++ b/modules/models/sd3/sd3_impls.py
@@ -67,6 +67,7 @@ class BaseModel(torch.nn.Module):
         }
         self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
         self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
+        self.depth = depth
 
     def apply_model(self, x, sigma, c_crossattn=None, y=None):
         dtype = self.get_dtype()
diff --git a/modules/models/sd3/sd3_model.py b/modules/models/sd3/sd3_model.py
index dbec8168f..37cf85eb3 100644
--- a/modules/models/sd3/sd3_model.py
+++ b/modules/models/sd3/sd3_model.py
@@ -82,3 +82,15 @@ class SD3Inferencer(torch.nn.Module):
 
     def fix_dimensions(self, width, height):
         return width // 16 * 16, height // 16 * 16
+
+    def diffusers_weight_mapping(self):
+        for i in range(self.model.depth):
+            yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj"
+
+            yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block_attn_qkv_q_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block_attn_qkv_k_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block_attn_qkv_v_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj"
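
Note (not part of the patch): two ideas above are worth seeing in isolation. First,
load_network() splits diffusers-style LoRA keys from the right, so the module path
survives intact while the "lora_A.weight" / "lora_B.weight" tail comes off cleanly.
Second, the QkvLinear branch works because SD3's fused qkv projection stacks the q, k
and v weight matrices along dim 0, so per-projection LoRA deltas can be computed on
each third and re-fused in one add. A minimal standalone sketch of both, assuming
only plain torch; the lora_delta() helper is illustrative, not part of the webui code:

import torch

# Key handling: rsplit(".", 2) keeps dots inside the module path.
key = "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"
module_path, lora_name, tensor_name = key.rsplit(".", 2)
assert module_path == "transformer.transformer_blocks.0.attn.to_q"
assert (lora_name, tensor_name) == ("lora_A", "weight")

def lora_delta(orig_weight, down, up, scale=1.0):
    # Standard LoRA update, cast to the target weight's dtype
    # (roughly what calc_updown() returns in the patch).
    return (scale * (up @ down)).to(orig_weight.dtype)

dim, rank = 8, 2
qkv_weight = torch.randn(dim * 3, dim)  # fused [q; k; v], stacked along dim 0

# One low-rank pair per projection, as the _q/_k/_v LoRA modules would hold.
downs = [torch.randn(rank, dim) for _ in range(3)]
ups = [torch.randn(dim, rank) for _ in range(3)]

# Mirror networks.py: chunk the fused weight, compute each delta, re-stack.
qw, kw, vw = qkv_weight.chunk(3, 0)
updown_qkv = torch.vstack([
    lora_delta(qw, downs[0], ups[0]),
    lora_delta(kw, downs[1], ups[1]),
    lora_delta(vw, downs[2], ups[2]),
])
qkv_weight += updown_qkv  # same shape as the fused weight, applied in one step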