From 7f31142c2eaaaec11f7e7c461f300caf9c1c5661 Mon Sep 17 00:00:00 2001
From: Isamu Isozaki
Date: Thu, 29 Sep 2022 00:26:02 +0900
Subject: [PATCH] Added script to save during textual inversion training.
 Issue 524 (#645)

* Added script to save during training

* Suggested changes
---
 .../textual_inversion/textual_inversion.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 53b4cf2f..253063e7 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -29,8 +29,21 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 logger = get_logger(__name__)
 
 
+def save_progress(text_encoder, placeholder_token_id, accelerator, args):
+    logger.info("Saving embeddings")
+    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+    torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--save_steps",
+        type=int,
+        default=500,
+        help="Save learned_embeds.bin every X updates steps.",
+    )
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -542,6 +555,8 @@ def main():
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
                     global_step += 1
+                    if global_step % args.save_steps == 0:
+                        save_progress(text_encoder, placeholder_token_id, accelerator, args)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
@@ -567,9 +582,7 @@ def main():
         )
         pipeline.save_pretrained(args.output_dir)
         # Also save the newly trained embeddings
-        learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
-        learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
-        torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
+        save_progress(text_encoder, placeholder_token_id, accelerator, args)
 
     if args.push_to_hub:
         repo.push_to_hub(
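
Usage note (not part of the patch): the snippet below is a minimal sketch of how the learned_embeds.bin file written by save_progress could be loaded back for inference. The base checkpoint name ("openai/clip-vit-large-patch14") and the output path are assumptions for illustration; the single-entry {placeholder_token: embedding} dict layout follows from the diff above.

import os

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Assumed base tokenizer/text encoder; swap in whatever the training run used.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# save_progress stores a dict of the form {args.placeholder_token: embedding_tensor}.
learned_embeds_dict = torch.load(os.path.join("output_dir", "learned_embeds.bin"))
placeholder_token, embedding = next(iter(learned_embeds_dict.items()))

# Register the placeholder token and grow the embedding matrix to hold it.
tokenizer.add_tokens(placeholder_token)
text_encoder.resize_token_embeddings(len(tokenizer))

# Copy the trained vector into the new token's row.
token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
with torch.no_grad():
    text_encoder.get_input_embeddings().weight[token_id] = embedding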