Added store/restore to EMAModel, restore non-EMA weights after saving checkpoint

John Kim 2022-11-07 00:28:55 +00:00
parent 517f3e154f
commit 94603d9578
1 changed file with 31 additions and 0 deletions


@@ -462,6 +462,33 @@ class EMAModel:
         for s_param, param in zip(self.shadow_params, parameters):
             param.data.copy_(s_param.data)
 
+    # From CompVis LitEMA implementation
+    def store(self, parameters):
+        """
+        Save the current parameters for restoring later.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored.
+        """
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters):
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters.
+        """
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
+        del self.collected_params
+        gc.collect()
+
     def to(self, device=None, dtype=None) -> None:
         r"""Move internal buffers of the ExponentialMovingAverage to `device`.
         Args:
@@ -602,6 +629,7 @@ def main():
     def save_checkpoint(global_step):
         if rank == 0:
             if args.use_ema:
+                ema_unet.store(unet.parameters())
                 ema_unet.copy_to(unet.parameters())
             pipeline = StableDiffusionPipeline(
                 text_encoder=text_encoder,
@@ -616,6 +644,9 @@ def main():
             )
             print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
            pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
+            if args.use_ema:
+                ema_unet.restore(unet.parameters())
         # barrier
         torch.distributed.barrier()
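
For reference, below is a minimal, self-contained sketch of the store -> copy_to -> restore round-trip this commit adds around checkpoint saving. ToyEMA is a hypothetical stand-in for the repo's EMAModel (its constructor and decay update are not part of this diff); only store, copy_to, and restore mirror the patched code, and the actual pipeline.save_pretrained call is reduced to a comment.

# Sketch only: ToyEMA is an illustrative stand-in, not the repo's EMAModel.
import gc
import torch
import torch.nn as nn


class ToyEMA:
    def __init__(self, parameters):
        # Pretend these clones are the EMA ("shadow") weights.
        self.shadow_params = [p.clone().detach() for p in parameters]

    def copy_to(self, parameters):
        # Overwrite the live weights with the EMA weights (pre-existing method).
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def store(self, parameters):
        # Snapshot the live (non-EMA) weights so they can be put back later.
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        # Put the snapshotted weights back and drop the temporary copies.
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
        del self.collected_params
        gc.collect()


model = nn.Linear(4, 4)
ema = ToyEMA(model.parameters())

# Drift the live weights away from the EMA copy, as training would.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

before = [p.detach().clone() for p in model.parameters()]

ema.store(model.parameters())    # keep the current training weights
ema.copy_to(model.parameters())  # swap in EMA weights for the checkpoint
# ... save the checkpoint / run validation here ...
ema.restore(model.parameters())  # return to the training weights

# Training resumes from exactly the weights it had before the checkpoint.
assert all(torch.equal(b, a) for b, a in zip(before, model.parameters()))

Without the restore step, training after save_checkpoint would silently continue from the EMA weights instead of the optimizer's current weights, which is the behavior this commit fixes.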