Added store/restore to EMAModel, restore non-EMA weights after saving checkpoint
This commit is contained in:
parent
517f3e154f
commit
94603d9578
|
@ -462,6 +462,33 @@ class EMAModel:
|
|||
for s_param, param in zip(self.shadow_params, parameters):
|
||||
param.data.copy_(s_param.data)
|
||||
|
||||
# From CompVis LitEMA implementation
|
||||
def store(self, parameters):
|
||||
"""
|
||||
Save the current parameters for restoring later.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
temporarily stored.
|
||||
"""
|
||||
self.collected_params = [param.clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters):
|
||||
"""
|
||||
Restore the parameters stored with the `store` method.
|
||||
Useful to validate the model with EMA parameters without affecting the
|
||||
original optimization process. Store the parameters before the
|
||||
`copy_to` method. After validation (or model saving), use this to
|
||||
restore the former parameters.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored parameters.
|
||||
"""
|
||||
for c_param, param in zip(self.collected_params, parameters):
|
||||
param.data.copy_(c_param.data)
|
||||
|
||||
del self.collected_params
|
||||
gc.collect()
|
||||
|
||||
def to(self, device=None, dtype=None) -> None:
|
||||
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
|
||||
Args:
|
||||
|
@ -602,6 +629,7 @@ def main():
|
|||
def save_checkpoint(global_step):
|
||||
if rank == 0:
|
||||
if args.use_ema:
|
||||
ema_unet.store(unet.parameters())
|
||||
ema_unet.copy_to(unet.parameters())
|
||||
pipeline = StableDiffusionPipeline(
|
||||
text_encoder=text_encoder,
|
||||
|
@ -616,6 +644,9 @@ def main():
|
|||
)
|
||||
print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
|
||||
pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
|
||||
|
||||
if args.use_ema:
|
||||
ema_unet.restore(unet.parameters())
|
||||
# barrier
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
|
Loading…
Reference in New Issue