fix distribution logic
This commit is contained in:
parent
20ce83d3e1
commit
69158fa51b
|
@ -20,6 +20,7 @@ class Accumulnator(BasePlugin):
|
||||||
|
|
||||||
# spread the grad accums
|
# spread the grad accums
|
||||||
curve = config['curve']
|
curve = config['curve']
|
||||||
|
steps = end_epoch - begin_epoch
|
||||||
if curve == 'linear':
|
if curve == 'linear':
|
||||||
accums = torch.linspace(start=begin_grad_accum,
|
accums = torch.linspace(start=begin_grad_accum,
|
||||||
end=end_grad_accum,
|
end=end_grad_accum,
|
||||||
|
@ -28,14 +29,15 @@ class Accumulnator(BasePlugin):
|
||||||
accums = torch.logspace(start=math.log(begin_grad_accum, 2),
|
accums = torch.logspace(start=math.log(begin_grad_accum, 2),
|
||||||
end=math.log(end_grad_accum, 2),
|
end=math.log(end_grad_accum, 2),
|
||||||
base=2,
|
base=2,
|
||||||
steps=10).tolist()
|
steps=steps).tolist()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"curve not {curve} not recognized")
|
raise NotImplementedError(f"curve not {curve} not recognized")
|
||||||
|
#print(f"accums: {accums}")
|
||||||
accums_per_epoch = {}
|
accums_per_epoch = {}
|
||||||
for i in range(begin_epoch):
|
for i in range(begin_epoch):
|
||||||
accums_per_epoch[i] = begin_grad_accum
|
accums_per_epoch[i] = begin_grad_accum
|
||||||
for i in range(end_grad_accum-begin_grad_accum):
|
for i in range(steps):
|
||||||
|
#print(f"took accum {accums[i]} for epoch {i+begin_epoch}")
|
||||||
accums_per_epoch[i+begin_epoch] = round(accums[i])
|
accums_per_epoch[i+begin_epoch] = round(accums[i])
|
||||||
self.per_epoch_grad_accum = accums_per_epoch
|
self.per_epoch_grad_accum = accums_per_epoch
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue