fix: calculate butterfly factor

This commit is contained in:
v0xie 2024-02-07 04:51:22 -08:00
parent 9588721197
commit a4668a16b6
1 changed files with 3 additions and 0 deletions

View File

@ -57,6 +57,9 @@ class NetworkModuleOFT(network.NetworkModule):
self.constraint = self.alpha * self.out_dim self.constraint = self.alpha * self.out_dim
self.num_blocks = self.dim self.num_blocks = self.dim
self.block_size = self.out_dim // self.dim self.block_size = self.out_dim // self.dim
elif self.is_boft:
self.constraint = None
self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim)
else: else:
self.constraint = None self.constraint = None
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)