SD3 lora support

This commit is contained in:
AUTOMATIC1111 2024-07-15 08:31:55 +03:00
parent b2453d280a
commit 7e5cdaab4b
6 changed files with 106 additions and 24 deletions

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from modules import sd_models, cache, errors, hashes, shared
import modules.models.sd3.mmdit
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -114,7 +115,10 @@ class NetworkModule:
self.sd_key = weights.sd_key
self.sd_module = weights.sd_module
if hasattr(self.sd_module, 'weight'):
if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
s = self.sd_module.weight.shape
self.shape = (s[0] // 3, s[1])
elif hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
elif isinstance(self.sd_module, nn.MultiheadAttention):
# For now, only self-attn uses PyTorch's MHA
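The new QkvLinear branch above reports the shape of a single projection rather than of the fused weight, since the q, k and v LoRA modules each target one third of the stacked matrix. A minimal standalone sketch of that arithmetic (an editor's illustration, not part of this commit):

import torch.nn as nn

class QkvLinear(nn.Linear):
    # marker subclass so LoRA code can recognize a fused q/k/v projection
    pass

dim = 8
qkv = QkvLinear(dim, dim * 3, bias=False)   # weight shape: (3 * dim, dim)

s = qkv.weight.shape
per_projection_shape = (s[0] // 3, s[1])    # shape seen by each q/k/v LoRA module
print(s, per_projection_shape)              # torch.Size([24, 8]) (8, 8)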

View File

@@ -1,6 +1,7 @@
import torch
import lyco_helpers
import modules.models.sd3.mmdit
import network
from modules import devices
@@ -10,6 +11,13 @@ class ModuleTypeLora(network.ModuleType):
if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
return NetworkModuleLora(net, weights)
if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]):
w = weights.w.copy()
weights.w.clear()
weights.w.update({"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]})
return NetworkModuleLora(net, weights)
return None
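This added branch accepts diffusers/PEFT-style checkpoints by renaming their lora_A.weight / lora_B.weight tensors to the kohya-style lora_down.weight / lora_up.weight keys the existing NetworkModuleLora path expects (lora_A is the down projection, lora_B the up projection). A standalone illustration of the remap with dummy tensors (not webui code; shapes are arbitrary):

import torch

rank, n_in, n_out = 4, 16, 16
weights = {
    "lora_A.weight": torch.zeros(rank, n_in),   # down projection: (rank, in_features)
    "lora_B.weight": torch.zeros(n_out, rank),  # up projection:   (out_features, rank)
}

w = weights.copy()
weights.clear()
weights.update({
    "lora_up.weight": w["lora_B.weight"],
    "lora_down.weight": w["lora_A.weight"],
})

assert set(weights) == {"lora_up.weight", "lora_down.weight"}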
@@ -29,7 +37,7 @@ class NetworkModuleLora(network.NetworkModule):
if weight is None and none_ok:
return None
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
if is_linear:

View File

@@ -20,6 +20,7 @@ from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
import modules.textual_inversion.textual_inversion as textual_inversion
import modules.models.sd3.mmdit
from lora_logger import logger
@@ -166,12 +167,26 @@ def load_network(name, network_on_disk):
keys_failed_to_match = {}
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
if hasattr(shared.sd_model, 'diffusers_weight_map'):
diffusers_weight_map = shared.sd_model.diffusers_weight_map
elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):
diffusers_weight_map = {}
for k, v in shared.sd_model.diffusers_weight_mapping():
diffusers_weight_map[k] = v
shared.sd_model.diffusers_weight_map = diffusers_weight_map
else:
diffusers_weight_map = None
matched_networks = {}
bundle_embeddings = {}
for key_network, weight in sd.items():
key_network_without_network_parts, _, network_part = key_network.partition(".")
if diffusers_weight_map:
key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
network_part = network_name + '.' + network_weight
else:
key_network_without_network_parts, _, network_part = key_network.partition(".")
if key_network_without_network_parts == "bundle_emb":
emb_name, vec_name = network_part.split(".", 1)
@@ -183,7 +198,11 @@ def load_network(name, network_on_disk):
emb_dict[vec_name] = weight
bundle_embeddings[emb_name] = emb_dict
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
if diffusers_weight_map:
key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)
else:
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
if sd_module is None:
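When a diffusers weight map is available, keys are split with rsplit instead of partition because diffusers-style module paths contain dots; only the trailing "<tensor>.weight" pair is peeled off, and the remaining path is resolved through the map rather than through convert_diffusers_name_to_compvis. A small sketch of the two splits on an illustrative key:

key_network = "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"

# new path: keep the dotted module path intact, peel off "<name>.<weight>"
key_without_parts, network_name, network_weight = key_network.rsplit(".", 2)
network_part = network_name + "." + network_weight
print(key_without_parts)  # transformer.transformer_blocks.0.attn.to_q
print(network_part)       # lora_A.weight

# legacy path: everything after the first dot becomes the network part
legacy_key, _, legacy_part = key_network.partition(".")
print(legacy_key)         # transformer
print(legacy_part)        # transformer_blocks.0.attn.to_q.lora_A.weight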
@@ -347,6 +366,28 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
purge_networks_from_memory()
def allowed_layer_without_weight(layer):
if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
return True
return False
def store_weights_backup(weight):
if weight is None:
return None
return weight.to(devices.cpu, copy=True)
def restore_weights_backup(obj, field, weight):
if weight is None:
setattr(obj, field, None)
return
getattr(obj, field).copy_(weight)
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
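store_weights_backup and restore_weights_backup centralize the CPU backup/restore logic and tolerate None (for example a bias-less layer, or a LayerNorm without affine weights). A self-contained sketch of the same two helpers applied to a plain Linear (devices.cpu in the webui is the CPU device; the string "cpu" stands in for it here):

import torch
import torch.nn as nn

def store_weights_backup(weight):
    if weight is None:
        return None
    return weight.to("cpu", copy=True)

def restore_weights_backup(obj, field, weight):
    if weight is None:
        setattr(obj, field, None)
        return
    getattr(obj, field).copy_(weight)

layer = nn.Linear(4, 4, bias=False)
backup_w = store_weights_backup(layer.weight)   # CPU copy of the original weight
backup_b = store_weights_backup(layer.bias)     # bias is None, so the backup is None

with torch.no_grad():
    layer.weight.add_(1.0)                             # simulate a LoRA delta being applied
    restore_weights_backup(layer, "weight", backup_w)  # copy the original back in place

assert torch.equal(layer.weight, backup_w)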
@@ -356,21 +397,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
else:
self.weight.copy_(weights_backup)
restore_weights_backup(self, 'weight', weights_backup)
if bias_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.out_proj.bias.copy_(bias_backup)
else:
self.bias.copy_(bias_backup)
if isinstance(self, torch.nn.MultiheadAttention):
restore_weights_backup(self.out_proj, 'bias', bias_backup)
else:
if isinstance(self, torch.nn.MultiheadAttention):
self.out_proj.bias = None
else:
self.bias = None
restore_weights_backup(self, 'bias', bias_backup)
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
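The MultiheadAttention branch dispatches the restore differently because PyTorch keeps the fused input projection on the module itself (in_proj_weight) while the output projection lives on the out_proj submodule. A quick look at that attribute layout (a sketch, not webui code):

import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

print(mha.in_proj_weight.shape)   # torch.Size([24, 8]) -- fused q/k/v input projection
print(mha.out_proj.weight.shape)  # torch.Size([8, 8])
print(mha.out_proj.bias.shape)    # torch.Size([8])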
@@ -389,22 +424,22 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
weights_backup = getattr(self, "network_weights_backup", None)
if weights_backup is None and wanted_names != ():
if current_names != ():
raise RuntimeError("no backup weights found and current weights are not unchanged")
if current_names != () and not allowed_layer_without_weight(self):
raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
else:
weights_backup = self.weight.to(devices.cpu, copy=True)
weights_backup = store_weights_backup(self.weight)
self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None and wanted_names != ():
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
bias_backup = store_weights_backup(self.out_proj.bias)
elif getattr(self, 'bias', None) is not None:
bias_backup = self.bias.to(devices.cpu, copy=True)
bias_backup = store_weights_backup(self.bias)
else:
bias_backup = None
@@ -412,6 +447,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
# Only report if the bias is not None and the current bias has been changed.
if bias_backup is not None and current_names != ():
raise RuntimeError("no backup bias found and current bias are not unchanged")
self.network_bias_backup = bias_backup
if current_names != wanted_names:
@@ -419,7 +455,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'):
if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
try:
with torch.no_grad():
if getattr(self, 'fp16_weight', None) is None:
@@ -479,6 +515,24 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
continue
if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
try:
with torch.no_grad():
# Send the "real" per-projection orig_weight into each q/k/v lora module
qw, kw, vw = self.weight.chunk(3, 0)
updown_q, _ = module_q.calc_updown(qw)
updown_k, _ = module_k.calc_updown(kw)
updown_v, _ = module_v.calc_updown(vw)
del qw, kw, vw
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
self.weight += updown_qkv
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
continue
if module is None:
continue
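For a QkvLinear layer, the block above chunks the fused weight into its q, k and v thirds, asks each matched LoRA module for its delta against the corresponding third, and stacks the three deltas back so a single in-place add updates the fused projection. A standalone sketch with random tensors standing in for the calc_updown() results (the webui module API itself is not reproduced here):

import torch
import torch.nn as nn

dim = 8
qkv = nn.Linear(dim, dim * 3, bias=False)   # fused q/k/v projection, weight (3 * dim, dim)

with torch.no_grad():
    # split the fused weight into its q, k and v thirds along dim 0
    qw, kw, vw = qkv.weight.chunk(3, 0)

    # stand-ins for module_q/k/v.calc_updown(qw) etc.; each delta matches one third
    updown_q = torch.randn_like(qw)
    updown_k = torch.randn_like(kw)
    updown_v = torch.randn_like(vw)

    # stack the three deltas back into the fused layout and apply them in one add
    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
    qkv.weight += updown_qkv

assert updown_qkv.shape == qkv.weight.shape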

View File

@@ -175,6 +175,9 @@ class VectorEmbedder(nn.Module):
#################################################################################
class QkvLinear(torch.nn.Linear):
pass
def split_qkv(qkv, head_dim):
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
return qkv[0], qkv[1], qkv[2]
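split_qkv turns a batched fused QKV activation of shape (batch, seq, 3 * num_heads * head_dim) into three tensors of shape (batch, seq, num_heads, head_dim). A quick shape check (toy sizes, mirroring the helper above):

import torch

def split_qkv(qkv, head_dim):
    qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
    return qkv[0], qkv[1], qkv[2]

batch, seq, num_heads, head_dim = 2, 5, 4, 16
fused = torch.randn(batch, seq, 3 * num_heads * head_dim)  # output of the QkvLinear

q, k, v = split_qkv(fused, head_dim)
print(q.shape)  # torch.Size([2, 5, 4, 16]); k and v have the same shape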
@@ -202,7 +205,7 @@ class SelfAttention(nn.Module):
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
if not pre_only:
self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
assert attn_mode in self.ATTENTION_MODES

View File

@@ -67,6 +67,7 @@ class BaseModel(torch.nn.Module):
}
self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
self.depth = depth
def apply_model(self, x, sigma, c_crossattn=None, y=None):
dtype = self.get_dtype()

View File

@@ -82,3 +82,15 @@ class SD3Inferencer(torch.nn.Module):
def fix_dimensions(self, width, height):
return width // 16 * 16, height // 16 * 16
def diffusers_weight_mapping(self):
for i in range(self.model.depth):
yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj"