fix: ignore calc_scale() for COFT which has very small alpha

This commit is contained in:
v0xie 2023-11-15 03:08:50 -08:00
parent 7edd50f304
commit d6d0b22e66
1 changed files with 5 additions and 11 deletions

View File

@ -99,12 +99,9 @@ class NetworkModuleOFT(network.NetworkModule):
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
if not is_other_linear: if not is_other_linear:
#if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
# orig_weight=orig_weight.permute(1, 0)
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
# without this line the results are significantly worse / less accurate # ensure skew-symmetric matrix
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
@ -118,9 +115,6 @@ class NetworkModuleOFT(network.NetworkModule):
) )
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
#if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
# orig_weight=orig_weight.permute(1, 0)
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
output_shape = orig_weight.shape output_shape = orig_weight.shape
else: else:
@ -132,10 +126,10 @@ class NetworkModuleOFT(network.NetworkModule):
return self.finalize_updown(updown, orig_weight, output_shape) return self.finalize_updown(updown, orig_weight, output_shape)
def calc_updown(self, orig_weight): def calc_updown(self, orig_weight):
multiplier = self.multiplier() * self.calc_scale() # if alpha is a very small number as in coft, calc_scale will return a almost zero number so we ignore it
#if self.is_kohya: #multiplier = self.multiplier() * self.calc_scale()
# return self.calc_updown_kohya(orig_weight, multiplier) multiplier = self.multiplier()
#else:
return self.calc_updown_kb(orig_weight, multiplier) return self.calc_updown_kb(orig_weight, multiplier)
# override to remove the multiplier/scale factor; it's already multiplied in get_weight # override to remove the multiplier/scale factor; it's already multiplied in get_weight