fix: ignore calc_scale() for COFT, which has a very small alpha
This commit is contained in:
parent
7edd50f304
commit
d6d0b22e66
|
@ -99,12 +99,9 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||||
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
|
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
|
||||||
|
|
||||||
if not is_other_linear:
|
if not is_other_linear:
|
||||||
#if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
|
|
||||||
# orig_weight=orig_weight.permute(1, 0)
|
|
||||||
|
|
||||||
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
|
||||||
# without this line the results are significantly worse / less accurate
|
# ensure skew-symmetric matrix
|
||||||
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)
|
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)
|
||||||
|
|
||||||
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
@ -118,9 +115,6 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||||
)
|
)
|
||||||
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
||||||
|
|
||||||
#if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
|
|
||||||
# orig_weight=orig_weight.permute(1, 0)
|
|
||||||
|
|
||||||
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||||
output_shape = orig_weight.shape
|
output_shape = orig_weight.shape
|
||||||
else:
|
else:
|
||||||
|
@ -132,10 +126,10 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
|
|
||||||
def calc_updown(self, orig_weight):
|
def calc_updown(self, orig_weight):
|
||||||
multiplier = self.multiplier() * self.calc_scale()
|
# if alpha is a very small number, as in COFT, calc_scale() will return an almost-zero number, so we ignore it
|
||||||
#if self.is_kohya:
|
#multiplier = self.multiplier() * self.calc_scale()
|
||||||
# return self.calc_updown_kohya(orig_weight, multiplier)
|
multiplier = self.multiplier()
|
||||||
#else:
|
|
||||||
return self.calc_updown_kb(orig_weight, multiplier)
|
return self.calc_updown_kb(orig_weight, multiplier)
|
||||||
|
|
||||||
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
|
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
|
||||||
|
|
Loading…
Reference in New Issue