EMA model stepping updated to keep track of current step (#64)

EMA model stepping is done automatically now.
Tanishq Abraham 2022-07-04 02:53:15 -07:00 committed by GitHub
parent 94566e6dd8
commit 3abf4bc439
2 changed files with 5 additions and 3 deletions


@@ -130,7 +130,7 @@ def main(args):
 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
 optimizer.step()
 lr_scheduler.step()
-ema_model.step(model, global_step)
+ema_model.step(model)
 optimizer.zero_grad()
 progress_bar.update(1)
 progress_bar.set_postfix(
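With the step counter moved inside the EMA model, the training loop above simply calls ema_model.step(model). A minimal sketch of that calling pattern is below; the toy model, optimizer, scheduler, dataloader, and the no-op EMA stand-in are placeholders for illustration, not taken from this commit.

import torch

# Toy stand-ins so the loop runs; in the real script these come from the
# training setup (the names mirror the diff context above).
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
train_dataloader = [torch.randn(8, 4) for _ in range(3)]

class _NoOpEMA:
    # Placeholder with the new one-argument signature; the real EMAModel
    # averages weights and tracks its own optimization step.
    def step(self, model):
        pass

ema_model = _NoOpEMA()

for batch in train_dataloader:
    loss = model(batch).pow(2).mean()  # placeholder loss for illustration
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    lr_scheduler.step()
    ema_model.step(model)  # no global_step argument needed any more
    optimizer.zero_grad()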


@@ -43,6 +43,7 @@ class EMAModel:
 self.averaged_model = self.averaged_model.to(device=device)

 self.decay = 0.0
+self.optimization_step = 0

 def get_decay(self, optimization_step):
 """
@@ -57,11 +58,11 @@ class EMAModel:
 return max(self.min_value, min(value, self.max_value))

 @torch.no_grad()
-def step(self, new_model, optimization_step):
+def step(self, new_model):
 ema_state_dict = {}
 ema_params = self.averaged_model.state_dict()

-self.decay = self.get_decay(optimization_step)
+self.decay = self.get_decay(self.optimization_step)

 for key, param in new_model.named_parameters():
 if isinstance(param, dict):
@@ -85,3 +86,4 @@ class EMAModel:
 ema_state_dict[key] = param

 self.averaged_model.load_state_dict(ema_state_dict, strict=False)
+self.optimization_step += 1
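Taken together, the change means EMAModel owns its own counter: step() reads self.optimization_step to compute the decay and increments it afterwards, so callers no longer thread a global step through. A condensed, self-contained sketch of that pattern follows; min_value and max_value appear in the diff above, while the warm-up schedule and the update_after_step, inv_gamma, and power fields are assumptions for illustration, not taken from this commit.

import copy

import torch


class TinyEMA:
    """Illustrative EMA helper that keeps its own optimization-step counter."""

    def __init__(self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3,
                 min_value=0.0, max_value=0.9999):
        self.averaged_model = copy.deepcopy(model).eval()
        self.averaged_model.requires_grad_(False)
        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.decay = 0.0
        self.optimization_step = 0  # tracked internally instead of being passed in

    def get_decay(self, optimization_step):
        # Assumed warm-up schedule: decay ramps from 0 toward max_value as steps accumulate.
        step = max(0, optimization_step - self.update_after_step - 1)
        if step <= 0:
            return 0.0
        value = 1 - (1 + step / self.inv_gamma) ** (-self.power)
        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model):
        # Decay comes from the internal counter, so callers only pass the model.
        self.decay = self.get_decay(self.optimization_step)
        for ema_param, param in zip(self.averaged_model.parameters(),
                                    new_model.parameters()):
            ema_param.mul_(self.decay).add_(param.data, alpha=1 - self.decay)
        self.optimization_step += 1


# Two updates: the counter advances without any step bookkeeping by the caller.
net = torch.nn.Linear(4, 4)
ema = TinyEMA(net)
ema.step(net)
ema.step(net)
print(ema.optimization_step, ema.decay)  # counter is 2; decay from the latest call

Keeping the counter inside the EMA object removes the risk of the trainer passing a stale or inconsistent step; the trade-off, presumably, is that restoring an EMA model mid-training also means restoring optimization_step along with its weights.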