diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index bb9884b3..e184eb0c 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -121,6 +121,12 @@ def parse_args():
         default=5000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
+    parser.add_argument(
+        "--save_steps",
+        type=int,
+        default=500,
+        help="Save the learned embeddings every X update steps.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -136,6 +142,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--lr_scheduler",
         type=str,
@@ -420,9 +433,9 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
-    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision)
 
     # Create sampling rng
     rng = jax.random.PRNGKey(args.seed)
@@ -619,6 +632,12 @@ def main():
 
             if global_step >= args.max_train_steps:
                 break
+            if global_step % args.save_steps == 0:
+                learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
+                    placeholder_token_id
+                ]
+                learned_embeds_dict = {args.placeholder_token: learned_embeds}
+                jnp.save(os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict)
 
         train_metric = jax_utils.unreplicate(train_metric)
 
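Note (not part of the patch): a minimal sketch of how an intermediate checkpoint written by the new --save_steps logic could be read back. It assumes the default NumPy pickling behavior, since jnp.save stores the {placeholder_token: embedding} dict as a pickled object array; the file path below is only an example.

    import numpy as np

    # The saved payload is a Python dict, so allow_pickle=True is required on load;
    # .item() unwraps the 0-d object array that np.save produces for dict inputs.
    learned = np.load("output_dir/learned_embeds-500.npy", allow_pickle=True).item()
    for token, embedding in learned.items():
        print(token, embedding.shape)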