From f6c8201e5663ca2182a66c8eca63ce4801d52849 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Fri, 3 Nov 2023 19:35:15 -0700 Subject: [PATCH] refactor: move factorization to lyco_helpers, separate calc_updown for kohya and kb --- extensions-builtin/Lora/lyco_helpers.py | 47 +++++++++ extensions-builtin/Lora/network_oft.py | 131 ++++++------------------ 2 files changed, 77 insertions(+), 101 deletions(-) diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py index 279b34bc9..1679a0ce6 100644 --- a/extensions-builtin/Lora/lyco_helpers.py +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid): up = up.reshape(up.size(0), -1) down = down.reshape(down.size(0), -1) return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) + + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + secon value is a value for weight. + + Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 979a20476..2be67fe53 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,7 +1,7 @@ import torch import network +from lyco_helpers import factorization from einops import rearrange -from modules import devices class ModuleTypeOFT(network.ModuleType): @@ -11,7 +11,8 @@ class ModuleTypeOFT(network.ModuleType): return None -# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): @@ -19,6 +20,7 @@ class NetworkModuleOFT(network.NetworkModule): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -37,61 +39,31 @@ class NetworkModuleOFT(network.NetworkModule): is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] - is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] - #if "Linear" in self.sd_module.__class__.__name__ or is_linear: + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + if is_linear: self.out_dim = self.sd_module.out_features - #elif hasattr(self.sd_module, "embed_dim"): - # self.out_dim = self.sd_module.embed_dim - #else: - # raise ValueError("Linear sd_module must have out_features or embed_dim") elif is_other_linear: self.out_dim = self.sd_module.embed_dim - #self.org_weight = self.org_module[0].weight -# if hasattr(self.sd_module, "in_proj_weight"): -# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] -# if hasattr(self.sd_module, "out_proj_weight"): -# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0] -# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1] elif is_conv: self.out_dim = self.sd_module.out_channels else: raise ValueError("sd_module must be Linear or Conv") - if self.is_kohya: self.num_blocks = self.dim self.block_size = self.out_dim // self.num_blocks self.constraint = self.alpha * self.out_dim - #elif is_linear or is_conv: else: self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.constraint = None - - # if is_other_linear: - # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) - # module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - # with torch.no_grad(): - # if weight.shape != module.weight.shape: - # weight = weight.reshape(module.weight.shape) - # module.weight.copy_(weight) - # module.to(device=devices.cpu, dtype=devices.dtype) - # module.weight.requires_grad_(False) - # self.lin_module = module - #return module - def merge_weight(self, R_weight, org_weight): R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) if org_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) else: weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - #weight = torch.einsum( - # "k n m, k n ... -> k m ...", - # self.oft_diag * scale + torch.eye(self.block_size, device=device), - # org_weight - #) return weight def get_weight(self, oft_blocks, multiplier=None): @@ -111,48 +83,51 @@ class NetworkModuleOFT(network.NetworkModule): block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I R = torch.block_diag(*block_R_weighted) return R - #return self.oft_blocks + def calc_updown_kohya(self, orig_weight, multiplier): + R = self.get_weight(self.oft_blocks, multiplier) + merged_weight = self.merge_weight(R, orig_weight) - def calc_updown(self, orig_weight): - multiplier = self.multiplier() * self.calc_scale() - is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] - if self.is_kohya and not is_other_linear: - R = self.get_weight(self.oft_blocks, multiplier) - #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - merged_weight = self.merge_weight(R, orig_weight) - elif not self.is_kohya and not is_other_linear: + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + output_shape = orig_weight.shape + orig_weight = orig_weight + return self.finalize_updown(updown, orig_weight, output_shape) + + def calc_updown_kb(self, orig_weight, multiplier): + is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] + + if not is_other_linear: if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: orig_weight=orig_weight.permute(1, 0) + R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks) merged_weight = torch.einsum( 'k n m, k n ... -> k m ...', R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - merged_weight + merged_weight ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]: orig_weight=orig_weight.permute(1, 0) - #merged_weight=merged_weight.permute(1, 0) + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape else: - # skip for now + # FIXME: skip MultiheadAttention for now updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype) output_shape = (orig_weight.shape[1], orig_weight.shape[1]) - #if self.lin_module is not None: - # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) - # weight = torch.mul(torch.mul(R, multiplier), orig_weight) - #else: - - orig_weight = orig_weight - return self.finalize_updown(updown, orig_weight, output_shape) + def calc_updown(self, orig_weight): + multiplier = self.multiplier() * self.calc_scale() + if self.is_kohya: + return self.calc_updown_kohya(orig_weight, multiplier) + else: + return self.calc_updown_kb(orig_weight, multiplier) + # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias) @@ -172,49 +147,3 @@ class NetworkModuleOFT(network.NetworkModule): ex_bias = ex_bias * self.multiplier() return updown, ex_bias - -# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py -def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: - ''' - return a tuple of two value of input dimension decomposed by the number closest to factor - second value is higher or equal than first value. - - In LoRA with Kroneckor Product, first value is a value for weight scale. - secon value is a value for weight. - - Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. - - examples) - factor - -1 2 4 8 16 ... - 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 - 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 - 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 - 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 - 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 - 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 - ''' - - if factor > 0 and (dimension % factor) == 0: - m = factor - n = dimension // factor - if m > n: - n, m = m, n - return m, n - if factor < 0: - factor = dimension - m, n = 1, dimension - length = m + n - while m length or new_m>factor: - break - else: - m, n = new_m, new_n - if m > n: - n, m = m, n - return m, n -