Merge pull request #31 from john-sungjin/fix-ema-checkpoint-saving
Restore non-EMA weights after saving checkpoint
Commit 1ea31eb71e
@@ -462,6 +462,33 @@ class EMAModel:
         for s_param, param in zip(self.shadow_params, parameters):
             param.data.copy_(s_param.data)

+    # From CompVis LitEMA implementation
+    def store(self, parameters):
+        """
+        Save the current parameters for restoring later.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored.
+        """
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters):
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters.
+        """
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
+
+        del self.collected_params
+        gc.collect()
+
     def to(self, device=None, dtype=None) -> None:
         r"""Move internal buffers of the ExponentialMovingAverage to `device`.
         Args:
@@ -602,6 +629,7 @@ def main():
     def save_checkpoint(global_step):
         if rank == 0:
             if args.use_ema:
+                ema_unet.store(unet.parameters())
                 ema_unet.copy_to(unet.parameters())
             pipeline = StableDiffusionPipeline(
                 text_encoder=text_encoder,
@@ -616,6 +644,9 @@ def main():
             )
             print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
             pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
+
+            if args.use_ema:
+                ema_unet.restore(unet.parameters())
         # barrier
         torch.distributed.barrier()

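For context, the sequence this fix relies on is store (stash the live weights), then copy_to (swap the EMA weights into the model for the saved pipeline), then restore (put the live weights back so training does not continue from EMA weights). Below is a minimal, self-contained sketch of that round-trip; TinyEMA and the toy nn.Linear are illustrative stand-ins for this sketch only, not the project's EMAModel or UNet.

import torch
import torch.nn as nn

# Illustrative stand-in exposing the same store/copy_to/restore surface as the
# methods added in the diff above.
class TinyEMA:
    def __init__(self, parameters, decay=0.999):
        self.decay = decay
        self.shadow_params = [p.clone().detach() for p in parameters]
        self.collected_params = None

    def step(self, parameters):
        # shadow <- decay * shadow + (1 - decay) * live
        for s, p in zip(self.shadow_params, parameters):
            s.mul_(self.decay).add_(p.data, alpha=1 - self.decay)

    def copy_to(self, parameters):
        for s, p in zip(self.shadow_params, parameters):
            p.data.copy_(s)

    def store(self, parameters):
        self.collected_params = [p.clone() for p in parameters]

    def restore(self, parameters):
        for c, p in zip(self.collected_params, parameters):
            p.data.copy_(c)
        self.collected_params = None

model = nn.Linear(4, 4)
ema = TinyEMA(model.parameters())

# A few fake optimization steps so live and EMA weights diverge.
for _ in range(3):
    for p in model.parameters():
        p.data.add_(torch.randn_like(p) * 0.1)
    ema.step(model.parameters())

live = [p.clone() for p in model.parameters()]
ema.store(model.parameters())    # stash the live (non-EMA) weights
ema.copy_to(model.parameters())  # swap in EMA weights, e.g. to save a checkpoint
ema.restore(model.parameters())  # put the live weights back
assert all(torch.equal(a, b) for a, b in zip(live, model.parameters()))

In the training script above, this is what wrapping pipeline.save_pretrained(...) with ema_unet.store(...), ema_unet.copy_to(...), and ema_unet.restore(...) achieves: the saved pipeline contains the EMA weights, while the in-memory unet keeps its non-EMA training weights.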