EMA model stepping updated to keep track of current step (#64)
EMA model stepping is now done automatically.
This commit is contained in:
parent
94566e6dd8
commit
3abf4bc439
|
@ -130,7 +130,7 @@ def main(args):
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
ema_model.step(model, global_step)
|
ema_model.step(model)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
progress_bar.set_postfix(
|
progress_bar.set_postfix(
|
||||||
|
|
|
@ -43,6 +43,7 @@ class EMAModel:
|
||||||
self.averaged_model = self.averaged_model.to(device=device)
|
self.averaged_model = self.averaged_model.to(device=device)
|
||||||
|
|
||||||
self.decay = 0.0
|
self.decay = 0.0
|
||||||
|
self.optimization_step = 0
|
||||||
|
|
||||||
def get_decay(self, optimization_step):
|
def get_decay(self, optimization_step):
|
||||||
"""
|
"""
|
||||||
|
@ -57,11 +58,11 @@ class EMAModel:
|
||||||
return max(self.min_value, min(value, self.max_value))
|
return max(self.min_value, min(value, self.max_value))
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def step(self, new_model, optimization_step):
|
def step(self, new_model):
|
||||||
ema_state_dict = {}
|
ema_state_dict = {}
|
||||||
ema_params = self.averaged_model.state_dict()
|
ema_params = self.averaged_model.state_dict()
|
||||||
|
|
||||||
self.decay = self.get_decay(optimization_step)
|
self.decay = self.get_decay(self.optimization_step)
|
||||||
|
|
||||||
for key, param in new_model.named_parameters():
|
for key, param in new_model.named_parameters():
|
||||||
if isinstance(param, dict):
|
if isinstance(param, dict):
|
||||||
|
@ -85,3 +86,4 @@ class EMAModel:
|
||||||
ema_state_dict[key] = param
|
ema_state_dict[key] = param
|
||||||
|
|
||||||
self.averaged_model.load_state_dict(ema_state_dict, strict=False)
|
self.averaged_model.load_state_dict(ema_state_dict, strict=False)
|
||||||
|
self.optimization_step += 1
|
||||||
|
|
Loading…
Reference in New Issue