refactor: use forward hook instead of custom forward
This commit is contained in:
parent
0550659ce6
commit
2d8c894b27
|
@ -36,9 +36,11 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||
# how do we revert this to unload the weights?
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module[0].forward
|
||||
self.org_module[0].forward = self.forward
|
||||
#self.org_module[0].forward = self.forward
|
||||
self.org_module[0].register_forward_hook(self.forward_hook)
|
||||
|
||||
def get_weight(self, oft_blocks, multiplier=None):
|
||||
self.constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype)
|
||||
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
|
||||
norm_Q = torch.norm(block_Q.flatten())
|
||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
||||
|
@ -66,14 +68,10 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||
output_shape = self.oft_blocks.shape
|
||||
|
||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
|
||||
def forward(self, x, y=None):
|
||||
x = self.org_forward(x)
|
||||
if self.multiplier() == 0.0:
|
||||
return x
|
||||
|
||||
# calculating R here is excruciatingly slow
|
||||
#R = self.get_weight().to(x.device, dtype=x.dtype)
|
||||
|
||||
def forward_hook(self, module, args, output):
|
||||
#print(f'Forward hook in {self.network_key} called')
|
||||
x = output
|
||||
R = self.R.to(x.device, dtype=x.dtype)
|
||||
|
||||
if x.dim() == 4:
|
||||
|
@ -83,3 +81,20 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||
else:
|
||||
x = torch.matmul(x, R)
|
||||
return x
|
||||
|
||||
# def forward(self, x, y=None):
|
||||
# x = self.org_forward(x)
|
||||
# if self.multiplier() == 0.0:
|
||||
# return x
|
||||
|
||||
# # calculating R here is excruciatingly slow
|
||||
# #R = self.get_weight().to(x.device, dtype=x.dtype)
|
||||
# R = self.R.to(x.device, dtype=x.dtype)
|
||||
|
||||
# if x.dim() == 4:
|
||||
# x = x.permute(0, 2, 3, 1)
|
||||
# x = torch.matmul(x, R)
|
||||
# x = x.permute(0, 3, 1, 2)
|
||||
# else:
|
||||
# x = torch.matmul(x, R)
|
||||
# return x
|
||||
|
|
Loading…
Reference in New Issue