diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index 8a6639cf0..084c41d08 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -14,14 +14,28 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
 
         names = []
-        multipliers = []
+        te_multipliers = []
+        unet_multipliers = []
+        dyn_dims = []
         for params in params_list:
             assert params.items
 
-            names.append(params.items[0])
-            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+            names.append(params.positional[0])
 
-        networks.load_networks(names, multipliers)
+            te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
+            te_multiplier = float(params.named.get("te", te_multiplier))
+
+            unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else 1.0
+            unet_multiplier = float(params.named.get("unet", unet_multiplier))
+
+            dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
+            dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
+
+            te_multipliers.append(te_multiplier)
+            unet_multipliers.append(unet_multiplier)
+            dyn_dims.append(dyn_dim)
+
+        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
 
         if shared.opts.lora_add_hashes_to_infotext:
             network_hashes = []
diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py
index 9ea499fbd..279b34bc9 100644
--- a/extensions-builtin/Lora/lyco_helpers.py
+++ b/extensions-builtin/Lora/lyco_helpers.py
@@ -13,3 +13,9 @@ def rebuild_conventional(up, down, shape, dyn_dim=None):
         up = up[:, :dyn_dim]
         down = down[:dyn_dim, :]
     return (up @ down).reshape(shape)
+
+
+def rebuild_cp_decomposition(up, down, mid):
+    up = up.reshape(up.size(0), -1)
+    down = down.reshape(down.size(0), -1)
+    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index 4ac637221..fe42dbdd4 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -68,7 +68,9 @@ class Network:  # LoraModule
     def __init__(self, name, network_on_disk: NetworkOnDisk):
         self.name = name
         self.network_on_disk = network_on_disk
-        self.multiplier = 1.0
+        self.te_multiplier = 1.0
+        self.unet_multiplier = 1.0
+        self.dyn_dim = None
         self.modules = {}
         self.mtime = None
 
@@ -88,6 +90,42 @@ class NetworkModule:
         self.sd_key = weights.sd_key
         self.sd_module = weights.sd_module
 
+        if hasattr(self.sd_module, 'weight'):
+            self.shape = self.sd_module.weight.shape
+
+        self.dim = None
+        self.bias = weights.w.get("bias")
+        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
+        self.scale = weights.w["scale"].item() if "scale" in weights.w else None
+
+    def multiplier(self):
+        if 'transformer' in self.sd_key[:20]:
+            return self.network.te_multiplier
+        else:
+            return self.network.unet_multiplier
+
+    def calc_scale(self):
+        if self.scale is not None:
+            return self.scale
+        if self.dim is not None and self.alpha is not None:
+            return self.alpha / self.dim
+
+        return 1.0
+
+    def finalize_updown(self, updown, orig_weight, output_shape):
+        if self.bias is not None:
+            updown = updown.reshape(self.bias.shape)
+            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown = updown.reshape(output_shape)
+
+        if len(output_shape) == 4:
+            updown = updown.reshape(output_shape)
+
+        if orig_weight.size().numel() == updown.size().numel():
+            updown = updown.reshape(orig_weight.shape)
+
+        return updown * self.calc_scale() * self.multiplier()
+
     def calc_updown(self, target):
         raise NotImplementedError()
 
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py
new file mode 100644
index 000000000..f0d8a6e05
--- /dev/null
+++ b/extensions-builtin/Lora/network_full.py
@@ -0,0 +1,23 @@
+import lyco_helpers
+import network
+
+
+class ModuleTypeFull(network.ModuleType):
+    def create_module(self, net: network.Network, weights: network.NetworkWeights):
+        if all(x in weights.w for x in ["diff"]):
+            return NetworkModuleFull(net, weights)
+
+        return None
+
+
+class NetworkModuleFull(network.NetworkModule):
+    def __init__(self, net: network.Network, weights: network.NetworkWeights):
+        super().__init__(net, weights)
+
+        self.weight = weights.w.get("diff")
+
+    def calc_updown(self, orig_weight):
+        output_shape = self.weight.shape
+        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+
+        return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py
index 799bb3bc0..5fcb0695f 100644
--- a/extensions-builtin/Lora/network_hada.py
+++ b/extensions-builtin/Lora/network_hada.py
@@ -1,6 +1,5 @@
 import lyco_helpers
 import network
-import network_lyco
 
 
 class ModuleTypeHada(network.ModuleType):
@@ -11,7 +10,7 @@ class ModuleTypeHada(network.ModuleType):
         return None
 
 
-class NetworkModuleHada(network_lyco.NetworkModuleLyco):
+class NetworkModuleHada(network.NetworkModule):
     def __init__(self, net: network.Network, weights: network.NetworkWeights):
         super().__init__(net, weights)
 
diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py
index d8806da02..7edc42497 100644
--- a/extensions-builtin/Lora/network_ia3.py
+++ b/extensions-builtin/Lora/network_ia3.py
@@ -1,5 +1,4 @@
 import network
-import network_lyco
 
 
 class ModuleTypeIa3(network.ModuleType):
@@ -10,7 +9,7 @@ class ModuleTypeIa3(network.ModuleType):
         return None
 
 
-class NetworkModuleIa3(network_lyco.NetworkModuleLyco):
+class NetworkModuleIa3(network.NetworkModule):
     def __init__(self, net: network.Network, weights: network.NetworkWeights):
         super().__init__(net, weights)
 
diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py
index f17319242..920062e24 100644
--- a/extensions-builtin/Lora/network_lokr.py
+++ b/extensions-builtin/Lora/network_lokr.py
@@ -2,7 +2,6 @@ import torch
 
 import lyco_helpers
 import network
-import network_lyco
 
 
 class ModuleTypeLokr(network.ModuleType):
@@ -22,7 +21,7 @@ def make_kron(orig_shape, w1, w2):
     return torch.kron(w1, w2).reshape(orig_shape)
 
 
-class NetworkModuleLokr(network_lyco.NetworkModuleLyco):
+class NetworkModuleLokr(network.NetworkModule):
     def __init__(self, net: network.Network, weights: network.NetworkWeights):
         super().__init__(net, weights)
 
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index b2d965373..26c0a72c2 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -1,5 +1,6 @@
 import torch
 
+import lyco_helpers
 import network
 from modules import devices
 
@@ -16,29 +17,42 @@ class NetworkModuleLora(network.NetworkModule):
     def __init__(self, net: network.Network, weights: network.NetworkWeights):
         super().__init__(net, weights)
 
self.create_module(weights.w["lora_up.weight"]) - self.down = self.create_module(weights.w["lora_down.weight"]) - self.alpha = weights.w["alpha"] if "alpha" in weights.w else None + self.up_model = self.create_module(weights.w, "lora_up.weight") + self.down_model = self.create_module(weights.w, "lora_down.weight") + self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True) + + self.dim = weights.w["lora_down.weight"].shape[0] + + def create_module(self, weights, key, none_ok=False): + weight = weights.get(key) - def create_module(self, weight, none_ok=False): if weight is None and none_ok: return None - if type(self.sd_module) == torch.nn.Linear: + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + + if is_linear: + weight = weight.reshape(weight.shape[0], -1) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.MultiheadAttention: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1): + elif is_conv and key == "lora_down.weight" or key == "dyn_up": + if len(weight.shape) == 2: + weight = weight.reshape(weight.shape[0], -1, 1, 1) + + if weight.shape[2] != 1 or weight.shape[3] != 1: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + else: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif is_conv and key == "lora_mid.weight": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + elif is_conv and key == "lora_up.weight" or key == "dyn_down": module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) else: - print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') - return None + raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') with torch.no_grad(): + if weight.shape != module.weight.shape: + weight = weight.reshape(module.weight.shape) module.weight.copy_(weight) module.to(device=devices.cpu, dtype=devices.dtype) @@ -46,25 +60,27 @@ class NetworkModuleLora(network.NetworkModule): return module - def calc_updown(self, target): - up = self.up.weight.to(target.device, dtype=target.dtype) - down = self.down.weight.to(target.device, dtype=target.dtype) + def calc_updown(self, orig_weight): + up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): - updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + output_shape = [up.size(0), down.size(1)] + if 
+        if self.mid_model is not None:
+            # cp-decomposition
+            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
+            output_shape += mid.shape[2:]
         else:
-            updown = up @ down
+            if len(down.shape) == 4:
+                output_shape += down.shape[2:]
+            updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
 
-        updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
-
-        return updown
+        return self.finalize_updown(updown, orig_weight, output_shape)
 
     def forward(self, x, y):
-        self.up.to(device=devices.device)
-        self.down.to(device=devices.device)
+        self.up_model.to(device=devices.device)
+        self.down_model.to(device=devices.device)
 
-        return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
+        return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
diff --git a/extensions-builtin/Lora/network_lyco.py b/extensions-builtin/Lora/network_lyco.py
deleted file mode 100644
index fc135314b..000000000
--- a/extensions-builtin/Lora/network_lyco.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import network
-
-
-class NetworkModuleLyco(network.NetworkModule):
-    def __init__(self, net: network.Network, weights: network.NetworkWeights):
-        super().__init__(net, weights)
-
-        if hasattr(self.sd_module, 'weight'):
-            self.shape = self.sd_module.weight.shape
-
-        self.dim = None
-        self.bias = weights.w.get("bias")
-        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
-        self.scale = weights.w["scale"].item() if "scale" in weights.w else None
-
-    def finalize_updown(self, updown, orig_weight, output_shape):
-        if self.bias is not None:
-            updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
-            updown = updown.reshape(output_shape)
-
-        if len(output_shape) == 4:
-            updown = updown.reshape(output_shape)
-
-        if orig_weight.size().numel() == updown.size().numel():
-            updown = updown.reshape(orig_weight.shape)
-
-        scale = (
-            self.scale if self.scale is not None
-            else self.alpha / self.dim if self.dim is not None and self.alpha is not None
-            else 1.0
-        )
-
-        return updown * scale * self.network.multiplier
-
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 1b358561c..401430e86 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -6,6 +6,7 @@ import network_lora
 import network_hada
 import network_ia3
 import network_lokr
+import network_full
 import torch
 from typing import Union
 
@@ -17,6 +18,7 @@ module_types = [
     network_hada.ModuleTypeHada(),
     network_ia3.ModuleTypeIa3(),
     network_lokr.ModuleTypeLokr(),
+    network_full.ModuleTypeFull(),
 ]
 
 
@@ -52,6 +54,15 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
 
     m = []
 
+    if match(m, r"lora_unet_conv_in(.*)"):
+        return f'diffusion_model_input_blocks_0_0{m[0]}'
+
+    if match(m, r"lora_unet_conv_out(.*)"):
+        return f'diffusion_model_out_2{m[0]}'
+
+    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
+        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
+
     if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
         suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
         return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
@@ -179,7 +190,7 @@ def load_network(name, network_on_disk):
 
     return net
 
 
-def load_networks(names, multipliers=None):
+def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
     already_loaded = {}
 
     for net in loaded_networks:
@@ -218,7 +229,9 @@ def load_networks(names, multipliers=None):
             print(f"Couldn't find network with name {name}")
             continue
 
-        net.multiplier = multipliers[i] if multipliers else 1.0
+        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
+        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
+        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
         loaded_networks.append(net)
 
     if failed_to_load_networks:
@@ -250,7 +263,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         return
 
     current_names = getattr(self, "network_current_names", ())
-    wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks)
+    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
 
     weights_backup = getattr(self, "network_weights_backup", None)
     if weights_backup is None:
@@ -288,9 +301,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
             updown_k = module_k.calc_updown(self.in_proj_weight)
             updown_v = module_v.calc_updown(self.in_proj_weight)
             updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+            updown_out = module_out.calc_updown(self.out_proj.weight)
 
             self.in_proj_weight += updown_qkv
-            self.out_proj.weight += module_out.calc_updown(self.out_proj.weight)
+            self.out_proj.weight += updown_out
             continue
 
         if module is None:
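
Note (not part of the patch): a minimal sketch of how the te/unet/dyn arguments are interpreted by the ExtraNetworkLora.activate() hunk above, assuming `positional` and `named` carry the same values as `params.positional` / `params.named`; the helper name and the example tag below are illustrative only.

# Illustration of the new multiplier handling; `parse_lora_args` is a hypothetical
# stand-in for the logic inside ExtraNetworkLora.activate() above.
def parse_lora_args(positional, named):
    name = positional[0]

    # second positional value is the text-encoder multiplier, overridable via te=
    te_multiplier = float(positional[1]) if len(positional) > 1 else 1.0
    te_multiplier = float(named.get("te", te_multiplier))

    # third positional value is the U-Net multiplier, overridable via unet=
    unet_multiplier = float(positional[2]) if len(positional) > 2 else 1.0
    unet_multiplier = float(named.get("unet", unet_multiplier))

    # fourth positional value caps the effective rank, overridable via dyn=
    dyn_dim = int(positional[3]) if len(positional) > 3 else None
    dyn_dim = int(named["dyn"]) if "dyn" in named else dyn_dim

    return name, te_multiplier, unet_multiplier, dyn_dim


# e.g. a tag like <lora:mylora:0.8:unet=0.5:dyn=16> would plausibly arrive as:
assert parse_lora_args(["mylora", "0.8"], {"unet": "0.5", "dyn": "16"}) == ("mylora", 0.8, 0.5, 16)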
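
Also not part of the patch: a quick shape check for the einsum in rebuild_cp_decomposition, under the assumption that up/down/mid arrive with the 2D/2D/4D shapes the function reshapes them to; the dimension names below are illustrative.

import torch

# 'n m k l, i n, m j -> i j k l': mid carries the (rank_n, rank_m, kH, kW) core,
# up maps rank_n -> out_channels (i), down maps in_channels (j) -> rank_m.
out_ch, in_ch, rank_n, rank_m, kh, kw = 8, 16, 4, 4, 3, 3
up = torch.randn(out_ch, rank_n)           # lora_up flattened to 2D
down = torch.randn(rank_m, in_ch)          # lora_down flattened to 2D
mid = torch.randn(rank_n, rank_m, kh, kw)  # lora_mid keeps the spatial kernel

updown = torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
assert updown.shape == (out_ch, in_ch, kh, kw)  # matches the Conv2d weight delta shape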