fix: ignore calc_scale() for COFT which has very small alpha
This commit is contained in:
parent
7edd50f304
commit
d6d0b22e66
|
@ -99,12 +99,9 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
|
||||
|
||||
if not is_other_linear:
|
||||
#if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
|
||||
# orig_weight=orig_weight.permute(1, 0)
|
||||
|
||||
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
|
||||
# without this line the results are significantly worse / less accurate
|
||||
# ensure skew-symmetric matrix
|
||||
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)
|
||||
|
||||
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
|
@ -118,9 +115,6 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||
)
|
||||
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
||||
|
||||
#if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
|
||||
# orig_weight=orig_weight.permute(1, 0)
|
||||
|
||||
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||
output_shape = orig_weight.shape
|
||||
else:
|
||||
|
@ -132,10 +126,10 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
multiplier = self.multiplier() * self.calc_scale()
|
||||
#if self.is_kohya:
|
||||
# return self.calc_updown_kohya(orig_weight, multiplier)
|
||||
#else:
|
||||
# if alpha is a very small number as in coft, calc_scale will return a almost zero number so we ignore it
|
||||
#multiplier = self.multiplier() * self.calc_scale()
|
||||
multiplier = self.multiplier()
|
||||
|
||||
return self.calc_updown_kb(orig_weight, multiplier)
|
||||
|
||||
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
|
||||
|
|
Loading…
Reference in New Issue