diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 93402bb28..c45a8d23a 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -99,12 +99,9 @@ class NetworkModuleOFT(network.NetworkModule):
         is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
 
         if not is_other_linear:
-            #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
-            #    orig_weight=orig_weight.permute(1, 0)
-
             oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
 
-            # without this line the results are significantly worse / less accurate
+            # ensure skew-symmetric matrix
            oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)
 
             R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -118,9 +115,6 @@ class NetworkModuleOFT(network.NetworkModule):
             )
             merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
 
-            #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
-            #    orig_weight=orig_weight.permute(1, 0)
-
             updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
             output_shape = orig_weight.shape
         else:
@@ -132,10 +126,10 @@ class NetworkModuleOFT(network.NetworkModule):
         return self.finalize_updown(updown, orig_weight, output_shape)
 
     def calc_updown(self, orig_weight):
-        multiplier = self.multiplier() * self.calc_scale()
-        #if self.is_kohya:
-        #    return self.calc_updown_kohya(orig_weight, multiplier)
-        #else:
+        # if alpha is a very small number, as in COFT, calc_scale() returns an almost-zero value, so we ignore it
+        #multiplier = self.multiplier() * self.calc_scale()
+        multiplier = self.multiplier()
+
         return self.calc_updown_kb(orig_weight, multiplier)
 
     # override to remove the multiplier/scale factor; it's already multiplied in get_weight
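
For context on the "ensure skew-symmetric matrix" comment in the first hunk: OFT treats each block of the weight update as a rotation, and the standard way to get an orthogonal matrix from an unconstrained block is to skew-symmetrize it and then map it to an orthogonal matrix. The sketch below is illustrative only, not the extension's code: the num_blocks/block_size values are arbitrary, and it uses the exact Cayley transform from the OFT paper rather than whatever calc_updown_kb does with R after this point.

# Illustrative sketch only -- not the extension's code. Shapes are made up;
# torch is the only dependency.
import torch

num_blocks, block_size = 4, 8
oft_blocks = torch.randn(num_blocks, block_size, block_size)

# Subtracting the transpose projects each block onto the skew-symmetric
# matrices (A == -A.transpose(1, 2)), which is what the kept line in the
# diff does.
skew = oft_blocks - oft_blocks.transpose(1, 2)

# Cayley transform: Q = (I - A)^-1 (I + A) is exactly orthogonal whenever
# A is skew-symmetric, so each block acts as a pure rotation.
eye = torch.eye(block_size).repeat(num_blocks, 1, 1)
Q = torch.linalg.solve(eye - skew, eye + skew)

# Orthogonality check: Q^T Q should equal I up to float error.
err = (Q.transpose(1, 2) @ Q - eye).abs().max()
print(f"max |Q^T Q - I| = {err.item():.1e}")

The calc_updown change in the last hunk is independent of this: as the added comment says, COFT networks carry a very small alpha, so folding calc_scale() into the multiplier would shrink the update to almost nothing, which is why the patched method passes self.multiplier() alone down to calc_updown_kb.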