EMA model stepping updated to keep track of current step (#64)
ema model stepping done automatically now
parent 94566e6dd8
commit 3abf4bc439
@@ -130,7 +130,7 @@ def main(args):
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
             lr_scheduler.step()
-            ema_model.step(model, global_step)
+            ema_model.step(model)
             optimizer.zero_grad()
             progress_bar.update(1)
             progress_bar.set_postfix(
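With the step counter now tracked inside EMAModel, the training loop only passes the model. A minimal sketch of the surrounding loop, assuming hypothetical `train_dataloader` and `compute_loss` helpers (only the `ema_model.step(model)` call and its neighbours are taken from the hunk above):

for step, batch in enumerate(train_dataloader):
    loss = compute_loss(model, batch)  # placeholder loss computation
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    lr_scheduler.step()
    ema_model.step(model)  # decay now derives from the EMA model's internal optimization_step
    optimizer.zero_grad()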
@@ -43,6 +43,7 @@ class EMAModel:
         self.averaged_model = self.averaged_model.to(device=device)

         self.decay = 0.0
+        self.optimization_step = 0

     def get_decay(self, optimization_step):
         """
@@ -57,11 +58,11 @@ class EMAModel:
         return max(self.min_value, min(value, self.max_value))

     @torch.no_grad()
-    def step(self, new_model, optimization_step):
+    def step(self, new_model):
         ema_state_dict = {}
         ema_params = self.averaged_model.state_dict()

-        self.decay = self.get_decay(optimization_step)
+        self.decay = self.get_decay(self.optimization_step)

         for key, param in new_model.named_parameters():
             if isinstance(param, dict):
@@ -85,3 +86,4 @@ class EMAModel:
             ema_state_dict[key] = param

         self.averaged_model.load_state_dict(ema_state_dict, strict=False)
+        self.optimization_step += 1
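Taken together, the class changes amount to the counter pattern sketched below. This is an illustrative, self-contained sketch and not the library implementation: the warm-up formula in `get_decay` and the parameter-wise update are assumptions, only the handling of `optimization_step` mirrors the diff.

import copy

import torch


class EMASketch:
    """Minimal EMA helper that tracks its own optimization step (illustrative only)."""

    def __init__(self, model, min_value=0.0, max_value=0.9999):
        self.averaged_model = copy.deepcopy(model).eval()
        self.averaged_model.requires_grad_(False)
        self.min_value = min_value
        self.max_value = max_value
        self.decay = 0.0
        self.optimization_step = 0  # the counter now lives on the EMA object

    def get_decay(self, optimization_step):
        # Assumed warm-up schedule: decay ramps up with the step count.
        value = (1 + optimization_step) / (10 + optimization_step)
        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model):
        # Decay comes from the internally tracked step, so callers only pass the model.
        self.decay = self.get_decay(self.optimization_step)
        for ema_param, param in zip(self.averaged_model.parameters(), new_model.parameters()):
            ema_param.mul_(self.decay).add_(param, alpha=1 - self.decay)
        self.optimization_step += 1  # advance the counter after every update

Because the counter advances inside step(), callers no longer need to thread a global step through the training loop, which is exactly what the first hunk removes.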