Fix ema decay (#1868)
* Fix ema decay and clarify nomenclature. * Rename var.
This commit is contained in:
parent
b28ab30215
commit
a6e2c1fe5c
|
@ -278,24 +278,19 @@ class EMAModel:
|
|||
self.decay = decay
|
||||
self.optimization_step = 0
|
||||
|
||||
def get_decay(self, optimization_step):
|
||||
"""
|
||||
Compute the decay factor for the exponential moving average.
|
||||
"""
|
||||
value = (1 + optimization_step) / (10 + optimization_step)
|
||||
return 1 - min(self.decay, value)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, parameters):
|
||||
parameters = list(parameters)
|
||||
|
||||
self.optimization_step += 1
|
||||
self.decay = self.get_decay(self.optimization_step)
|
||||
|
||||
# Compute the decay factor for the exponential moving average.
|
||||
value = (1 + self.optimization_step) / (10 + self.optimization_step)
|
||||
one_minus_decay = 1 - min(self.decay, value)
|
||||
|
||||
for s_param, param in zip(self.shadow_params, parameters):
|
||||
if param.requires_grad:
|
||||
tmp = self.decay * (s_param - param)
|
||||
s_param.sub_(tmp)
|
||||
s_param.sub_(one_minus_decay * (s_param - param))
|
||||
else:
|
||||
s_param.copy_(param)
|
||||
|
||||
|
|
Loading…
Reference in New Issue