fix: calculate butterfly factor
This commit is contained in:
parent
9588721197
commit
a4668a16b6
|
@ -57,6 +57,9 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||||
self.constraint = self.alpha * self.out_dim
|
self.constraint = self.alpha * self.out_dim
|
||||||
self.num_blocks = self.dim
|
self.num_blocks = self.dim
|
||||||
self.block_size = self.out_dim // self.dim
|
self.block_size = self.out_dim // self.dim
|
||||||
|
elif self.is_boft:
|
||||||
|
self.constraint = None
|
||||||
|
self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim)
|
||||||
else:
|
else:
|
||||||
self.constraint = None
|
self.constraint = None
|
||||||
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||||
|
|
Loading…
Reference in New Issue