For no constraint
This commit is contained in:
parent
64179c3221
commit
c4afdb7895
|
@ -27,7 +27,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||
# kohya-ss/New LyCORIS OFT/BOFT
|
||||
if "oft_blocks" in weights.w.keys():
|
||||
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
|
||||
self.alpha = weights.w.get("alpha", self.alpha) # alpha is constraint
|
||||
self.alpha = weights.w.get("alpha", None) # alpha is constraint
|
||||
self.dim = self.oft_blocks.shape[0] # lora dim
|
||||
# Old LyCORIS OFT
|
||||
elif "oft_diag" in weights.w.keys():
|
||||
|
@ -56,7 +56,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||
|
||||
self.num_blocks = self.dim
|
||||
self.block_size = self.out_dim // self.dim
|
||||
self.constraint = (1 if self.alpha is None else self.alpha) * self.out_dim
|
||||
self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
|
||||
if self.is_R:
|
||||
self.constraint = None
|
||||
self.block_size = self.dim
|
||||
|
@ -73,9 +73,10 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||
|
||||
if not self.is_R:
|
||||
block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix
|
||||
norm_Q = torch.norm(block_Q.flatten())
|
||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
|
||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||
if self.constraint != 0:
|
||||
norm_Q = torch.norm(block_Q.flatten())
|
||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
|
||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
|
||||
|
||||
R = oft_blocks.to(orig_weight.device)
|
||||
|
|
Loading…
Reference in New Issue