Merge pull request #14726 from v0xie/fix-oft-device
Fix kohya-ss OFT network wrong device for eye and constraint
This commit is contained in:
commit
358e9e2847
|
@ -57,12 +57,12 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||||
|
|
||||||
def calc_updown(self, orig_weight):
|
def calc_updown(self, orig_weight):
|
||||||
oft_blocks = self.oft_blocks.to(orig_weight.device)
|
oft_blocks = self.oft_blocks.to(orig_weight.device)
|
||||||
eye = torch.eye(self.block_size, device=self.oft_blocks.device)
|
eye = torch.eye(self.block_size, device=oft_blocks.device)
|
||||||
|
|
||||||
if self.is_kohya:
|
if self.is_kohya:
|
||||||
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
||||||
norm_Q = torch.norm(block_Q.flatten())
|
norm_Q = torch.norm(block_Q.flatten())
|
||||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
|
||||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||||
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
|
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue