fix distribution logic

This commit is contained in:
Damian Stewart 2023-10-22 22:02:38 +02:00
parent 20ce83d3e1
commit 69158fa51b
1 changed files with 5 additions and 3 deletions

View File

@ -20,6 +20,7 @@ class Accumulnator(BasePlugin):
# spread the grad accums # spread the grad accums
curve = config['curve'] curve = config['curve']
steps = end_epoch - begin_epoch
if curve == 'linear': if curve == 'linear':
accums = torch.linspace(start=begin_grad_accum, accums = torch.linspace(start=begin_grad_accum,
end=end_grad_accum, end=end_grad_accum,
@ -28,14 +29,15 @@ class Accumulnator(BasePlugin):
accums = torch.logspace(start=math.log(begin_grad_accum, 2), accums = torch.logspace(start=math.log(begin_grad_accum, 2),
end=math.log(end_grad_accum, 2), end=math.log(end_grad_accum, 2),
base=2, base=2,
steps=10).tolist() steps=steps).tolist()
else: else:
raise NotImplementedError(f"curve not {curve} not recognized") raise NotImplementedError(f"curve not {curve} not recognized")
#print(f"accums: {accums}")
accums_per_epoch = {} accums_per_epoch = {}
for i in range(begin_epoch): for i in range(begin_epoch):
accums_per_epoch[i] = begin_grad_accum accums_per_epoch[i] = begin_grad_accum
for i in range(end_grad_accum-begin_grad_accum): for i in range(steps):
#print(f"took accum {accums[i]} for epoch {i+begin_epoch}")
accums_per_epoch[i+begin_epoch] = round(accums[i]) accums_per_epoch[i+begin_epoch] = round(accums[i])
self.per_epoch_grad_accum = accums_per_epoch self.per_epoch_grad_accum = accums_per_epoch