95 lines
3.1 KiB
Python
95 lines
3.1 KiB
Python
import torch
|
|
|
|
|
|
def make_weight_cp(t, wa, wb):
    """Contract a 4-D core tensor with two 2-D factor matrices.

    The einsum subscripts fix the ranks of the operands: ``t`` is indexed
    ``(i, j, k, l)``, ``wb`` is indexed ``(j, r)`` and ``wa`` is indexed
    ``(i, r)``.

    Args:
        t: 4-D tensor whose first two axes are contracted away.
        wa: 2-D tensor whose first axis matches axis 0 of ``t``.
        wb: 2-D tensor whose first axis matches axis 1 of ``t``.

    Returns:
        A 4-D tensor of shape ``(wa.size(1), wb.size(1), t.size(2), t.size(3))``.
    """
    # Contract axis 1 of ``t`` (j) against the rows of ``wb``.
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    # Contract axis 0 (i) against the rows of ``wa``.
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)
|
|
|
|
|
|
def rebuild_conventional(up, down, shape, dyn_dim=None):
    """Multiply two factor tensors and reshape the product to ``shape``.

    Both factors are flattened to 2-D, keeping their first axis, before the
    matrix product is taken.

    Args:
        up: Left factor; flattened to ``(up.size(0), -1)``.
        down: Right factor; flattened to ``(down.size(0), -1)``.
        shape: Target shape passed to ``reshape`` on the product.
        dyn_dim: When not ``None``, only the first ``dyn_dim`` columns of
            ``up`` and the first ``dyn_dim`` rows of ``down`` are used.

    Returns:
        ``up @ down`` reshaped to ``shape``.
    """
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        # Truncate the shared inner dimension on both factors.
        up = up[:, :dyn_dim]
        down = down[:dyn_dim, :]
    return (up @ down).reshape(shape)
|
|
|
|
|
|
def rebuild_cp_decomposition(up, down, mid):
    """Combine a 4-D core tensor with two flattened factor tensors.

    ``up`` and ``down`` are first flattened to 2-D, keeping their first axis.
    The einsum then indexes ``mid`` as ``(n, m, k, l)``, ``up`` as ``(i, n)``
    and ``down`` as ``(m, j)``.

    Args:
        up: Factor flattened to ``(up.size(0), -1)``; its second axis is
            contracted with axis 0 of ``mid``.
        down: Factor flattened to ``(down.size(0), -1)``; its first axis is
            contracted with axis 1 of ``mid``.
        mid: 4-D core tensor.

    Returns:
        A 4-D tensor indexed ``(i, j, k, l)``.
    """
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
|
|
|
|
|
|
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
|
|
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    """Split ``dimension`` into two integer factors ``(m, n)`` with ``m <= n``.

    If ``factor`` is positive and divides ``dimension``, the pair
    ``(factor, dimension // factor)`` is returned (ordered so the smaller value
    comes first). Otherwise the divisors of ``dimension`` are walked upwards and
    the last pair whose smaller candidate does not exceed ``factor`` is kept.
    A negative ``factor`` means "no limit" and yields the most balanced pair.

    In LoRA with Kronecker product, the first value is a value for weight scale
    and the second value is a value for weight. Because the Kronecker product is
    non-commutative (A⊗B ≠ B⊗A), the meaning of the two matrices is slightly
    different.

    Examples (columns are the value of ``factor``)::

        dim      -1         2          4          8          16
        127   1, 127     1, 127     1, 127     1, 127     1, 127
        128   8, 16      2, 64      4, 32      8, 16      8, 16
        250   10, 25     2, 125     2, 125     5, 50      10, 25
        360   18, 20     2, 180     4, 90      8, 45      15, 24
        512   16, 32     2, 256     4, 128     8, 64      16, 32
        1024  32, 32     2, 512     4, 256     8, 128     16, 64

    Args:
        dimension: The integer to decompose.
        factor: Preferred (or maximum) size of one factor; negative for no limit.

    Returns:
        A tuple ``(m, n)`` with ``m <= n``.
    """
    # Fast path: the requested factor divides the dimension exactly.
    if factor > 0 and dimension % factor == 0:
        other = dimension // factor
        return (factor, other) if factor <= other else (other, factor)

    if factor < 0:
        factor = dimension

    m, n = 1, dimension
    while m < n:
        # Next divisor of ``dimension`` strictly greater than ``m``.
        candidate = m + 1
        while dimension % candidate != 0:
            candidate += 1
        if candidate > factor:
            break
        m, n = candidate, dimension // candidate

    # The final step of the walk can overshoot past the square root.
    return (m, n) if m <= n else (n, m)
|
|
|
|
# from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/boft.py
|
|
def butterfly_factor(dimension: int, factor: int = -1) -> tuple[int, int]:
    """Split ``dimension`` into ``(m, n)`` under the BOFT constraints.

    The result satisfies ``m * n == dimension`` where ``m = 2k`` is an even
    number not larger than ``factor`` and ``n = 2**p`` is a power of two.
    Every even candidate ``m <= factor`` is tried in increasing order and the
    last one that works is kept, so the largest valid ``m`` wins.

    Args:
        dimension: The integer to decompose.
        factor: Upper bound for ``m``. With the default of ``-1`` the search
            loop never runs, so the call always raises ``ValueError``.

    Returns:
        The tuple ``(m, n)``.

    Raises:
        ValueError: If no even ``m <= factor`` leaves a power-of-two quotient.
    """
    m = n = 0
    while m <= factor:
        m += 2
        # Advance to the next even divisor of ``dimension`` (or past it).
        while dimension % m != 0 and m < dimension:
            m += 2
        if m > factor:
            break
        quotient = dimension // m
        # Power-of-two test on the integer itself; unlike parsing the binary
        # string, this cannot choke on the sign of a negative quotient.
        if quotient > 0 and (quotient & (quotient - 1)) == 0:
            n = quotient

    if n == 0:
        raise ValueError(
            f"It is impossible to decompose {dimension} with factor {factor} under BOFT constrains."
        )

    return dimension // n, n
|