Merge pull request #31 from john-sungjin/fix-ema-checkpoint-saving

Restore non-EMA weights after saving checkpoint
This commit is contained in:
Anthony Mercurio 2022-11-06 16:44:11 -08:00 committed by GitHub
commit 1ea31eb71e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 31 additions and 0 deletions

View File

@@ -462,6 +462,33 @@ class EMAModel:
for s_param, param in zip(self.shadow_params, parameters):
param.data.copy_(s_param.data)
# From CompVis LitEMA implementation
def store(self, parameters):
    """
    Save a snapshot of the current parameters for restoring later.

    Each parameter is detached before it is cloned, so the snapshot is a
    plain copy of the values that carries no `grad_fn` and keeps no
    reference to the autograd graph. Calling `store` again replaces any
    previously saved snapshot.

    Args:
        parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
    """
    # `clone()` alone is tracked by autograd when the parameter requires
    # grad; detaching first keeps the snapshot out of the graph.
    self.collected_params = [param.detach().clone() for param in parameters]
def restore(self, parameters):
    """
    Restore the parameters stored with the `store` method.

    Useful to validate the model with EMA parameters without affecting the
    original optimization process. Store the parameters before the
    `copy_to` method. After validation (or model saving), use this to
    restore the former parameters. The stored snapshot is discarded once
    it has been copied back, so every `restore` needs its own preceding
    `store`.

    Args:
        parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.

    Raises:
        RuntimeError: If no snapshot is available, i.e. `store` was never
            called or the snapshot was already consumed by `restore`.
    """
    if getattr(self, "collected_params", None) is None:
        raise RuntimeError(
            "restore() called without a preceding store(); "
            "there are no stored parameters to restore."
        )
    for c_param, param in zip(self.collected_params, parameters):
        # Write through `.data` so the copy is not recorded by autograd.
        param.data.copy_(c_param.data)
    # Drop the snapshot so the cloned tensors can be freed.
    del self.collected_params
    gc.collect()
def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
Args:
@@ -602,6 +629,7 @@ def main():
def save_checkpoint(global_step):
if rank == 0:
if args.use_ema:
ema_unet.store(unet.parameters())
ema_unet.copy_to(unet.parameters())
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
@@ -616,6 +644,9 @@ def main():
)
print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
if args.use_ema:
ema_unet.restore(unet.parameters())
# barrier
torch.distributed.barrier()