[textual_inversion] Fix resuming state when using gradient checkpointing (#2072)
* Fix resuming state when using gradient checkpointing. Also, allow --resume_from_checkpoint to be used when the checkpoint does not yet exist (a normal training run will be started).
* style
commit f4dddaf5ee
parent 7d8b4f7f8e
@@ -597,7 +597,7 @@ def main():
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )
 
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # For mixed precision training we cast the unet and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
@@ -643,14 +643,21 @@ def main():
             dirs = os.listdir(args.output_dir)
             dirs = [d for d in dirs if d.startswith("checkpoint")]
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
-            path = dirs[-1]
-        accelerator.print(f"Resuming from checkpoint {path}")
-        accelerator.load_state(os.path.join(args.output_dir, path))
-        global_step = int(path.split("-")[1])
-
-        resume_global_step = global_step * args.gradient_accumulation_steps
-        first_epoch = resume_global_step // num_update_steps_per_epoch
-        resume_step = resume_global_step % num_update_steps_per_epoch
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+        else:
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+            global_step = int(path.split("-")[1])
+
+            resume_global_step = global_step * args.gradient_accumulation_steps
+            first_epoch = global_step // num_update_steps_per_epoch
+            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
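Note on the last hunk: num_update_steps_per_epoch counts optimizer updates, while resume_global_step counts micro-batches (raw dataloader steps), so the old first_epoch / resume_step math mixed units whenever args.gradient_accumulation_steps > 1. A quick numerical illustration with made-up values (not taken from the script defaults):

# Made-up values, chosen only to illustrate the unit mismatch fixed in the hunk above.
gradient_accumulation_steps = 4
num_update_steps_per_epoch = 100                   # optimizer updates per epoch
global_step = 250                                  # optimizer updates done when the checkpoint was saved

resume_global_step = global_step * gradient_accumulation_steps        # 1000 micro-batches

# Old computation: divides a micro-batch count by an updates-per-epoch count.
old_first_epoch = resume_global_step // num_update_steps_per_epoch    # 10 (should be 2)
old_resume_step = resume_global_step % num_update_steps_per_epoch     # 0  (should be 200)

# New computation: epoch from optimizer updates, resume_step in micro-batches.
new_first_epoch = global_step // num_update_steps_per_epoch           # 2
new_resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)  # 200

With the old formulas a resumed run would jump to a far-too-late epoch and skip nothing inside it; the new formulas land on the correct epoch and micro-batch.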
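For completeness, a minimal self-contained sketch of how these values are expected to drive the resumed training loop. The numbers are hypothetical and the loop is a standalone simulation, not code from the repository; it assumes the loop skips resume_step raw dataloader steps in the first resumed epoch and performs one optimizer update per gradient_accumulation_steps micro-batches:

# Hypothetical values; variable names mirror the script, but this is a simulation.
gradient_accumulation_steps = 4
num_update_steps_per_epoch = 100                   # optimizer updates per epoch
num_train_epochs = 5
batches_per_epoch = num_update_steps_per_epoch * gradient_accumulation_steps

global_step = 250                                  # optimizer updates done when the checkpoint was saved
resume_global_step = global_step * gradient_accumulation_steps

# Resume bookkeeping as computed by the new code above.
first_epoch = global_step // num_update_steps_per_epoch     # 2
resume_step = resume_global_step % batches_per_epoch        # 200 micro-batches into epoch 2

updates_done = global_step
for epoch in range(first_epoch, num_train_epochs):
    for step in range(batches_per_epoch):
        # Skip the micro-batches that were already consumed before the checkpoint.
        if epoch == first_epoch and step < resume_step:
            continue
        # One optimizer update at the end of each accumulation window.
        if (step + 1) % gradient_accumulation_steps == 0:
            updates_done += 1

assert updates_done == num_train_epochs * num_update_steps_per_epoch  # 500

With the corrected first_epoch and resume_step the resumed run performs exactly the remaining 250 updates; with the old values it would not.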