Allow textual_inversion_flax script to use save_steps and revision flags (#2075)
* Update textual_inversion_flax.py
* Update textual_inversion_flax.py
* Typo sorry.
* Format source
This commit is contained in:
parent
b7b4683bdc
commit
f3f626d556
@@ -121,6 +121,12 @@ def parse_args():
         default=5000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
+    parser.add_argument(
+        "--save_steps",
+        type=int,
+        default=500,
+        help="Save learned_embeds.bin every X updates steps.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -136,6 +142,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--lr_scheduler",
         type=str,
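For reference, a standalone sketch of how the two new flags surface on args once parsed. The parser below is illustrative, not the script's own parse_args(); argument names, types, defaults, and help strings are copied from the hunks above, and the sample command-line values are hypothetical:

import argparse

# Stand-in for the script's parse_args(); only the two new arguments are shown.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--save_steps",
    type=int,
    default=500,
    help="Save learned_embeds.bin every X updates steps.",
)
parser.add_argument(
    "--revision",
    type=str,
    default=None,
    required=False,
    help="Revision of pretrained model identifier from huggingface.co/models.",
)

# Hypothetical invocation: python textual_inversion_flax.py --save_steps 250 --revision fp16
args = parser.parse_args(["--save_steps", "250", "--revision", "fp16"])
print(args.save_steps, args.revision)  # 250 fp16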
@@ -420,9 +433,9 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Load models and create wrapper for stable diffusion
-    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
-    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
-    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision)
 
     # Create sampling rng
     rng = jax.random.PRNGKey(args.seed)
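A note on what threading revision through here buys: from_pretrained forwards it to the Hub, so weights are resolved from a specific branch, tag, or commit hash instead of the default main. The sketch below is not part of the commit; the repo id and the "bf16" branch name (a branch that repo publishes with bfloat16 Flax weights) are examples only.

from diffusers import FlaxUNet2DConditionModel

# With this change, passing --revision bf16 to the script would resolve the
# checkpoint from the "bf16" branch of the Hub repo rather than "main".
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", revision="bf16"
)
print(unet.config.sample_size)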
@@ -619,6 +632,12 @@ def main():
 
             if global_step >= args.max_train_steps:
                 break
+            if global_step % args.save_steps == 0:
+                learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
+                    placeholder_token_id
+                ]
+                learned_embeds_dict = {args.placeholder_token: learned_embeds}
+                jnp.save(os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict)
 
         train_metric = jax_utils.unreplicate(train_metric)
 
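One practical note on the new save path: saving a Python dict through the .npy format stores it as a pickled 0-d object array, so reading an intermediate checkpoint back typically needs allow_pickle. A minimal sketch, with a hypothetical output path:

import numpy as np

# "output/learned_embeds-500.npy" stands in for the files the loop above
# writes every --save_steps steps; .item() unwraps the pickled dict.
learned = np.load("output/learned_embeds-500.npy", allow_pickle=True).item()
for token, embedding in learned.items():
    print(token, embedding.shape)  # e.g. "<my-token> (768,)" for SD v1 text encoders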