fix distribution logic
This commit is contained in:
parent
20ce83d3e1
commit
69158fa51b
|
@ -20,6 +20,7 @@ class Accumulnator(BasePlugin):
|
|||
|
||||
# spread the grad accums
|
||||
curve = config['curve']
|
||||
steps = end_epoch - begin_epoch
|
||||
if curve == 'linear':
|
||||
accums = torch.linspace(start=begin_grad_accum,
|
||||
end=end_grad_accum,
|
||||
|
@ -28,14 +29,15 @@ class Accumulnator(BasePlugin):
|
|||
accums = torch.logspace(start=math.log(begin_grad_accum, 2),
|
||||
end=math.log(end_grad_accum, 2),
|
||||
base=2,
|
||||
steps=10).tolist()
|
||||
steps=steps).tolist()
|
||||
else:
|
||||
raise NotImplementedError(f"curve not {curve} not recognized")
|
||||
|
||||
#print(f"accums: {accums}")
|
||||
accums_per_epoch = {}
|
||||
for i in range(begin_epoch):
|
||||
accums_per_epoch[i] = begin_grad_accum
|
||||
for i in range(end_grad_accum-begin_grad_accum):
|
||||
for i in range(steps):
|
||||
#print(f"took accum {accums[i]} for epoch {i+begin_epoch}")
|
||||
accums_per_epoch[i+begin_epoch] = round(accums[i])
|
||||
self.per_epoch_grad_accum = accums_per_epoch
|
||||
|
||||
|
|
Loading…
Reference in New Issue