refactor: use forward hook instead of custom forward
This commit is contained in:
parent
0550659ce6
commit
2d8c894b27
|
@ -36,9 +36,11 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||||
# how do we revert this to unload the weights?
|
# how do we revert this to unload the weights?
|
||||||
def apply_to(self):
|
def apply_to(self):
|
||||||
self.org_forward = self.org_module[0].forward
|
self.org_forward = self.org_module[0].forward
|
||||||
self.org_module[0].forward = self.forward
|
#self.org_module[0].forward = self.forward
|
||||||
|
self.org_module[0].register_forward_hook(self.forward_hook)
|
||||||
|
|
||||||
def get_weight(self, oft_blocks, multiplier=None):
|
def get_weight(self, oft_blocks, multiplier=None):
|
||||||
|
self.constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype)
|
||||||
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
|
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
|
||||||
norm_Q = torch.norm(block_Q.flatten())
|
norm_Q = torch.norm(block_Q.flatten())
|
||||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
||||||
|
@ -66,14 +68,10 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||||
output_shape = self.oft_blocks.shape
|
output_shape = self.oft_blocks.shape
|
||||||
|
|
||||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
|
|
||||||
def forward(self, x, y=None):
|
def forward_hook(self, module, args, output):
|
||||||
x = self.org_forward(x)
|
#print(f'Forward hook in {self.network_key} called')
|
||||||
if self.multiplier() == 0.0:
|
x = output
|
||||||
return x
|
|
||||||
|
|
||||||
# calculating R here is excruciatingly slow
|
|
||||||
#R = self.get_weight().to(x.device, dtype=x.dtype)
|
|
||||||
R = self.R.to(x.device, dtype=x.dtype)
|
R = self.R.to(x.device, dtype=x.dtype)
|
||||||
|
|
||||||
if x.dim() == 4:
|
if x.dim() == 4:
|
||||||
|
@ -83,3 +81,20 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||||
else:
|
else:
|
||||||
x = torch.matmul(x, R)
|
x = torch.matmul(x, R)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
# def forward(self, x, y=None):
|
||||||
|
# x = self.org_forward(x)
|
||||||
|
# if self.multiplier() == 0.0:
|
||||||
|
# return x
|
||||||
|
|
||||||
|
# # calculating R here is excruciatingly slow
|
||||||
|
# #R = self.get_weight().to(x.device, dtype=x.dtype)
|
||||||
|
# R = self.R.to(x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
# if x.dim() == 4:
|
||||||
|
# x = x.permute(0, 2, 3, 1)
|
||||||
|
# x = torch.matmul(x, R)
|
||||||
|
# x = x.permute(0, 3, 1, 2)
|
||||||
|
# else:
|
||||||
|
# x = torch.matmul(x, R)
|
||||||
|
# return x
|
||||||
|
|
Loading…
Reference in New Issue