[textual_inversion] Fix resuming state when using gradient checkpointing (#2072)

* Fix resuming state when using gradient checkpointing.

Also, allow --resume_from_checkpoint to be used when the checkpoint does
not yet exist (a normal training run will be started).

* style
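
The behaviour described above can be summarised with a small sketch. This is not the script itself; resolve_checkpoint is a hypothetical helper that mirrors the directory-listing logic in the diff below, where `dirs[-1]` used to raise an IndexError when --resume_from_checkpoint was given but no checkpoint-* folder existed yet:

    import os

    def resolve_checkpoint(output_dir: str):
        # Pick the most recent "checkpoint-<step>" folder, or None to start a fresh run.
        dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        # Before this fix the code did `path = dirs[-1]`, which fails on an empty list;
        # returning None lets the caller fall back to a normal training run.
        return dirs[-1] if len(dirs) > 0 else None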
Pedro Cuenca 2023-01-24 10:25:41 +01:00 committed by GitHub
parent 7d8b4f7f8e
commit f4dddaf5ee
1 changed file with 15 additions and 8 deletions


@@ -597,7 +597,7 @@ def main():
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )
 
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # For mixed precision training we cast the unet and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
@@ -643,14 +643,21 @@ def main():
             dirs = os.listdir(args.output_dir)
             dirs = [d for d in dirs if d.startswith("checkpoint")]
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
-            path = dirs[-1]
-        accelerator.print(f"Resuming from checkpoint {path}")
-        accelerator.load_state(os.path.join(args.output_dir, path))
-        global_step = int(path.split("-")[1])
+            path = dirs[-1] if len(dirs) > 0 else None
 
-        resume_global_step = global_step * args.gradient_accumulation_steps
-        first_epoch = resume_global_step // num_update_steps_per_epoch
-        resume_step = resume_global_step % num_update_steps_per_epoch
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+        else:
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+            global_step = int(path.split("-")[1])
+
+            resume_global_step = global_step * args.gradient_accumulation_steps
+            first_epoch = global_step // num_update_steps_per_epoch
+            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
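
Why the bookkeeping in the second hunk changes: global_step counts optimizer updates, while resume_global_step counts dataloader (micro-batch) steps, so with gradient accumulation the starting epoch has to come from global_step and the number of batches to skip has to wrap around num_update_steps_per_epoch * gradient_accumulation_steps. A worked example with assumed numbers (not taken from the script):

    # Assumed: 100 optimizer updates per epoch, 4 gradient accumulation steps,
    # and a checkpoint saved after 250 optimizer updates.
    num_update_steps_per_epoch = 100
    gradient_accumulation_steps = 4
    global_step = 250

    resume_global_step = global_step * gradient_accumulation_steps   # 1000 dataloader steps

    # Old arithmetic: epoch derived from dataloader steps.
    old_first_epoch = resume_global_step // num_update_steps_per_epoch   # 10, although training is only in epoch 2
    old_resume_step = resume_global_step % num_update_steps_per_epoch    # 0, so no batches would be skipped

    # Fixed arithmetic: epoch from optimizer updates, skip count in dataloader steps.
    first_epoch = global_step // num_update_steps_per_epoch              # 2
    resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)  # 200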