fix: AttributeError when attempting to reshape rescale by org_module weight
This commit is contained in:
parent
3e0146f9bd
commit
07805cbeee
|
@ -36,13 +36,6 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||||
# self.alpha is unused
|
# self.alpha is unused
|
||||||
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
||||||
|
|
||||||
# LyCORIS BOFT
|
|
||||||
if self.oft_blocks.dim() == 4:
|
|
||||||
self.is_boft = True
|
|
||||||
self.rescale = weights.w.get('rescale', None)
|
|
||||||
if self.rescale is not None:
|
|
||||||
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
|
|
||||||
|
|
||||||
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
||||||
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
||||||
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
|
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
|
||||||
|
@ -54,6 +47,13 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||||
elif is_other_linear:
|
elif is_other_linear:
|
||||||
self.out_dim = self.sd_module.embed_dim
|
self.out_dim = self.sd_module.embed_dim
|
||||||
|
|
||||||
|
# LyCORIS BOFT
|
||||||
|
if self.oft_blocks.dim() == 4:
|
||||||
|
self.is_boft = True
|
||||||
|
self.rescale = weights.w.get('rescale', None)
|
||||||
|
if self.rescale is not None and not is_other_linear:
|
||||||
|
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
|
||||||
|
|
||||||
self.num_blocks = self.dim
|
self.num_blocks = self.dim
|
||||||
self.block_size = self.out_dim // self.dim
|
self.block_size = self.out_dim // self.dim
|
||||||
self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
|
self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
|
||||||
|
|
Loading…
Reference in New Issue