From 94603d9578441456cdbadea84b112014f763fe01 Mon Sep 17 00:00:00 2001
From: John Kim
Date: Mon, 7 Nov 2022 00:28:55 +0000
Subject: [PATCH] Added store/restore to EMAModel, restore non-EMA weights
 after saving checkpoint

---
 diffusers_trainer.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/diffusers_trainer.py b/diffusers_trainer.py
index 719078a..b51c317 100644
--- a/diffusers_trainer.py
+++ b/diffusers_trainer.py
@@ -462,6 +462,33 @@ class EMAModel:
         for s_param, param in zip(self.shadow_params, parameters):
             param.data.copy_(s_param.data)
 
+    # From CompVis LitEMA implementation
+    def store(self, parameters):
+        """
+        Save the current parameters for restoring later.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored.
+        """
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters):
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters.
+        """
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
+
+        del self.collected_params
+        gc.collect()
+
     def to(self, device=None, dtype=None) -> None:
         r"""Move internal buffers of the ExponentialMovingAverage to `device`.
         Args:
@@ -602,6 +629,7 @@ def main():
     def save_checkpoint(global_step):
         if rank == 0:
             if args.use_ema:
+                ema_unet.store(unet.parameters())
                 ema_unet.copy_to(unet.parameters())
             pipeline = StableDiffusionPipeline(
                 text_encoder=text_encoder,
@@ -616,6 +644,9 @@
             )
             print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
             pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
+
+            if args.use_ema:
+                ema_unet.restore(unet.parameters())
         # barrier
         torch.distributed.barrier()
 
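
Note (not part of the patch): below is a minimal, self-contained sketch of the store -> copy_to -> restore pattern that the hunks above wire into save_checkpoint. The EMAShadow class and the toy torch.nn.Linear model are illustrative stand-ins, not code from the repository; the real trainer uses its own EMAModel and the Stable Diffusion UNet.

import torch

class EMAShadow:
    """Simplified stand-in for the trainer's EMAModel (illustrative only)."""

    def __init__(self, parameters, decay=0.999):
        self.decay = decay
        # shadow copies that accumulate the exponential moving average
        self.shadow_params = [p.clone().detach() for p in parameters]

    def step(self, parameters):
        # shadow <- decay * shadow + (1 - decay) * current
        for s, p in zip(self.shadow_params, parameters):
            s.mul_(self.decay).add_(p.data, alpha=1 - self.decay)

    def store(self, parameters):
        # stash the live (non-EMA) weights so training can resume from them
        self.collected_params = [p.clone() for p in parameters]

    def copy_to(self, parameters):
        # overwrite the live weights with the EMA weights
        for s, p in zip(self.shadow_params, parameters):
            p.data.copy_(s)

    def restore(self, parameters):
        # put the stashed live weights back after saving/evaluating
        for c, p in zip(self.collected_params, parameters):
            p.data.copy_(c)
        del self.collected_params


model = torch.nn.Linear(4, 4)
ema = EMAShadow(model.parameters())

# ... each optimizer step would be followed by ema.step(model.parameters()) ...

ema.store(model.parameters())    # 1. keep the live weights
ema.copy_to(model.parameters())  # 2. swap in the EMA weights
# 3. save or evaluate the model here (the checkpoint gets the EMA weights)
ema.restore(model.parameters())  # 4. continue training from the live weights

Without the store/restore calls, copy_to would permanently overwrite the live weights with the EMA weights at every checkpoint, so subsequent training steps would continue from the averaged parameters instead of the optimizer's current state.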