Fix ema decay (#1868)

* Fix ema decay and clarify nomenclature.

* Rename var.
Pedro Cuenca 2022-12-30 12:42:42 +01:00 committed by GitHub
parent b28ab30215
commit a6e2c1fe5c
1 changed file with 5 additions and 10 deletions


@@ -278,24 +278,19 @@ class EMAModel:
         self.decay = decay
         self.optimization_step = 0
 
-    def get_decay(self, optimization_step):
-        """
-        Compute the decay factor for the exponential moving average.
-        """
-        value = (1 + optimization_step) / (10 + optimization_step)
-        return 1 - min(self.decay, value)
-
     @torch.no_grad()
     def step(self, parameters):
         parameters = list(parameters)
 
         self.optimization_step += 1
-        self.decay = self.get_decay(self.optimization_step)
+
+        # Compute the decay factor for the exponential moving average.
+        value = (1 + self.optimization_step) / (10 + self.optimization_step)
+        one_minus_decay = 1 - min(self.decay, value)
+
         for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
-                tmp = self.decay * (s_param - param)
-                s_param.sub_(tmp)
+                s_param.sub_(one_minus_decay * (s_param - param))
             else:
                 s_param.copy_(param)
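
For reference, the corrected update is the standard EMA step shadow = decay * shadow + (1 - decay) * param, applied in-place as shadow -= (1 - decay) * (shadow - param), where decay follows the warm-up schedule min(self.decay, (1 + step) / (10 + step)). The key change is that self.decay keeps the configured maximum instead of being overwritten with 1 - decay on every call. Below is a minimal standalone sketch of that update; the helper name ema_update and the max_decay default of 0.9999 are illustrative assumptions, not part of the commit:

import torch


def ema_update(shadow: torch.Tensor, param: torch.Tensor, step: int, max_decay: float = 0.9999) -> None:
    """Apply one EMA step in-place: shadow = decay * shadow + (1 - decay) * param."""
    # Warm-up schedule from the diff: decay ramps from ~0.18 at step 1 toward max_decay.
    decay = min(max_decay, (1 + step) / (10 + step))
    one_minus_decay = 1 - decay
    # Same in-place form as the fixed step(): shadow -= (1 - decay) * (shadow - param).
    shadow.sub_(one_minus_decay * (shadow - param))


shadow = torch.ones(3)
param = torch.zeros(3)
ema_update(shadow, param, step=1000)  # decay = min(0.9999, 1001/1010) ≈ 0.9911
print(shadow)                         # tensor([0.9911, 0.9911, 0.9911])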