add text encoder LR setting to optimizer.json

Damian Stewart 2023-03-02 00:13:43 +01:00
parent ba87b0cae1
commit c82664b3f3
2 changed files with 32 additions and 14 deletions

optimizer.json

@@ -5,11 +5,13 @@
         "lr": "learning rate, if null wil use CLI or main JSON config value",
         "betas": "exponential decay rates for the moment estimates",
         "epsilon": "value added to denominator for numerical stability, unused for lion",
-        "weight_decay": "weight decay (L2 penalty)"
+        "weight_decay": "weight decay (L2 penalty)",
+        "text_encoder_lr_scale": "if set, scale the text encoder's LR by this much relative to the unet LR"
     },
     "optimizer": "adamw8bit",
     "lr": 1e-6,
     "betas": [0.9, 0.999],
     "epsilon": 1e-8,
-    "weight_decay": 0.010
+    "weight_decay": 0.010,
+    "text_encoder_lr_scale": 1.0
 }
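
For context, the arithmetic behind the new key: the scale multiplies whatever unet LR ends up in effect, so a value below 1.0 trains the text encoder more gently than the unet. A minimal standalone sketch of that behavior, assuming an optimizer.json shaped like the one above (illustrative only, not the trainer's actual loading code):

# Sketch of how "text_encoder_lr_scale" combines with "lr"; illustrative only.
import json

with open("optimizer.json") as f:
    config = json.load(f)

unet_lr = config["lr"] or 1e-6                      # JSON null -> fall back to a default
scale = config.get("text_encoder_lr_scale", 1.0)    # 1.0 leaves both LRs identical
text_encoder_lr = unet_lr * scale

print(f"unet LR: {unet_lr}, text encoder LR: {text_encoder_lr}")
# With "lr": 1e-6 and "text_encoder_lr_scale": 0.5 this would print 1e-06 and 5e-07.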

train.py

@@ -477,16 +477,6 @@ def main(args):
     else:
         text_encoder = text_encoder.to(device, dtype=torch.float32)
 
-    if args.disable_textenc_training:
-        logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
-        params_to_train = itertools.chain(unet.parameters())
-    elif args.disable_unet_training:
-        logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
-        params_to_train = itertools.chain(text_encoder.parameters())
-    else:
-        logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
-        params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
-
     optimizer_config = None
     optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
     if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
@@ -514,6 +504,7 @@ def main(args):
     default_lr = 1e-6
     curr_lr = args.lr
+    text_encoder_lr_scale = 1.0
 
     if optimizer_config is not None:
         betas = optimizer_config["betas"]
@@ -524,12 +515,30 @@ def main(args):
         if args.lr is not None:
             curr_lr = args.lr
             logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
+
+        text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale)
+        if text_encoder_lr_scale != 1.0:
+            print(f" * Using text encoder LR scale {text_encoder_lr_scale}")
+
         logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
 
     if curr_lr is None:
         curr_lr = default_lr
         logging.warning(f"No LR setting found, defaulting to {default_lr}")
 
+    text_encoder_lr = curr_lr * text_encoder_lr_scale
+
+    if args.disable_textenc_training:
+        logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
+        params_to_train = itertools.chain(unet.parameters())
+    elif args.disable_unet_training:
+        logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
+        params_to_train = itertools.chain(text_encoder.parameters())
+    else:
+        logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
+        params_to_train = [{'params': unet.parameters()},
+                           {'params': text_encoder.parameters(), 'lr': text_encoder_lr}]
+
     if optimizer_name:
         if optimizer_name == "lion":
             from lion_pytorch import Lion
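
The list-of-dicts form handed to the optimizer above is PyTorch's standard per-parameter-group mechanism: each dict is one parameter group, and a group-level 'lr' overrides the optimizer-wide default, which is what gives the text encoder its own scaled LR. A minimal sketch, using torch.optim.AdamW and tiny placeholder modules in place of the real unet and text encoder:

# Sketch of per-group LRs in PyTorch; unet/text_encoder here are placeholders.
import torch

unet = torch.nn.Linear(8, 8)            # stands in for the real unet
text_encoder = torch.nn.Linear(8, 8)    # stands in for the real text encoder

base_lr = 1e-6
text_encoder_lr = base_lr * 0.5         # e.g. text_encoder_lr_scale = 0.5

optimizer = torch.optim.AdamW([
    {"params": unet.parameters()},                                # inherits base_lr
    {"params": text_encoder.parameters(), "lr": text_encoder_lr},
], lr=base_lr)

for group in optimizer.param_groups:
    print(group["lr"])    # 1e-06 then 5e-07; group order is preserved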
@@ -803,7 +812,14 @@ def main(args):
                 loss_local = sum(loss_log_step) / len(loss_log_step)
                 loss_log_step = []
                 logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
-                log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step = global_step)
+                if args.disable_textenc_training or args.disable_unet_training:
+                    log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step = global_step)
+                else:
+                    curr_text_encoder_lr = lr_scheduler.get_last_lr()[1]
+                    log_writer.add_scalars(main_tag="hyperparamater/lr", tag_scalar_dict={
+                        'unet': curr_lr,
+                        'text encoder': curr_text_encoder_lr
+                    }, global_step = global_step)
                 log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
                 sum_img = sum(images_per_sec_log_step)
                 avg = sum_img / len(images_per_sec_log_step)
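
On the logging side, lr_scheduler.get_last_lr() returns one LR per parameter group in the order the groups were registered, so index 1 is the text encoder group created above, and SummaryWriter.add_scalars() (unlike add_scalar()) plots both series under a single main tag. A minimal sketch with stand-in modules and a constant scheduler (the real trainer configures its scheduler and writer elsewhere):

# Sketch of the two-LR logging path; the model and scheduler are stand-ins.
import torch
from torch.utils.tensorboard import SummaryWriter

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW([
    {"params": [model.weight]},              # group 0: "unet" stand-in
    {"params": [model.bias], "lr": 5e-7},    # group 1: "text encoder" stand-in
], lr=1e-6)
lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)

log_writer = SummaryWriter()
global_step = 0

# One LR per param group, matching registration order:
curr_lr, curr_text_encoder_lr = lr_scheduler.get_last_lr()

# Both series land on one chart; the tag spelling mirrors the diff above.
log_writer.add_scalars(main_tag="hyperparamater/lr", tag_scalar_dict={
    'unet': curr_lr,
    'text encoder': curr_text_encoder_lr
}, global_step=global_step)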