22 lines
647 B
Python
22 lines
647 B
Python
import torch
|
|
|
|
|
|
def make_weight_cp(t, wa, wb):
    """Contract the first two axes of a 4-D core tensor with two matrices.

    Computes ``out[r, s, k, l] = sum_{i, j} t[i, j, k, l] * wa[i, r] * wb[j, s]``.

    Args:
        t: 4-D tensor indexed ``(i, j, k, l)``.
        wa: 2-D tensor indexed ``(i, r)``; contracted against axis 0 of ``t``.
        wb: 2-D tensor indexed ``(j, s)``; contracted against axis 1 of ``t``.

    Returns:
        4-D tensor indexed ``(r, s, k, l)``; the last two axes of ``t`` are
        carried through unchanged.
    """
    # Two-step contraction: axis 1 with wb first, then axis 0 with wa.
    contracted = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', contracted, wa)
|
|
|
|
|
|
def rebuild_conventional(up, down, shape, dyn_dim=None):
    """Multiply two flattened factor tensors and reshape the product.

    Both factors are flattened to 2-D, keeping their first axis and merging
    all remaining axes, before the matrix product is taken.

    Args:
        up: Tensor flattened to ``(up.size(0), -1)``; left operand.
        down: Tensor flattened to ``(down.size(0), -1)``; right operand.
        shape: Target shape passed to ``reshape`` for the product.
        dyn_dim: Optional cut-off. When not ``None``, only the first
            ``dyn_dim`` columns of ``up`` and the first ``dyn_dim`` rows of
            ``down`` take part in the product.

    Returns:
        ``(up @ down)`` reshaped to ``shape``.
    """
    up_2d = up.reshape(up.size(0), -1)
    down_2d = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        # Truncate the shared inner dimension on both factors.
        up_2d = up_2d[:, :dyn_dim]
        down_2d = down_2d[:dyn_dim, :]
    return (up_2d @ down_2d).reshape(shape)
|
|
|
|
|
|
def rebuild_cp_decomposition(up, down, mid):
    """Combine a 4-D core tensor with two flattened factor matrices.

    Computes
    ``out[i, j, k, l] = sum_{n, m} mid[n, m, k, l] * up[i, n] * down[m, j]``
    after flattening ``up`` and ``down`` to 2-D (first axis kept, the rest
    merged).

    Args:
        up: Tensor flattened to ``(up.size(0), -1)``; indexed ``(i, n)``.
        down: Tensor flattened to ``(down.size(0), -1)``; indexed ``(m, j)``.
        mid: 4-D tensor indexed ``(n, m, k, l)``.

    Returns:
        4-D tensor indexed ``(i, j, k, l)``.
    """
    up_2d = up.reshape(up.size(0), -1)
    down_2d = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up_2d, down_2d)
|