add text encoder LR setting to optimizer.json

parent ba87b0cae1
commit c82664b3f3
optimizer.json

@@ -5,11 +5,13 @@
         "lr": "learning rate, if null will use CLI or main JSON config value",
         "betas": "exponential decay rates for the moment estimates",
         "epsilon": "value added to denominator for numerical stability, unused for lion",
-        "weight_decay": "weight decay (L2 penalty)"
+        "weight_decay": "weight decay (L2 penalty)",
+        "text_encoder_lr_scale": "if set, scale the text encoder's LR by this much relative to the unet LR"
     },
     "optimizer": "adamw8bit",
     "lr": 1e-6,
     "betas": [0.9, 0.999],
     "epsilon": 1e-8,
-    "weight_decay": 0.010
+    "weight_decay": 0.010,
+    "text_encoder_lr_scale": 1.0
 }
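For reference, the new key is a relative multiplier, not an absolute learning rate: the text encoder's LR is derived from the unet LR at startup. A minimal sketch of the arithmetic (the 0.5 override is a hypothetical example; the shipped default of 1.0 leaves both LRs equal), mirroring the train.py change below:

# Hypothetical illustration of how text_encoder_lr_scale is applied.
unet_lr = 1e-6                # "lr" from optimizer.json
text_encoder_lr_scale = 0.5   # example override; shipped default is 1.0
text_encoder_lr = unet_lr * text_encoder_lr_scale
print(f"unet LR: {unet_lr}, text encoder LR: {text_encoder_lr}")  # 1e-06 / 5e-07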
train.py (36 lines changed)

@@ -477,16 +477,6 @@ def main(args):
     else:
         text_encoder = text_encoder.to(device, dtype=torch.float32)
 
-    if args.disable_textenc_training:
-        logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
-        params_to_train = itertools.chain(unet.parameters())
-    elif args.disable_unet_training:
-        logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
-        params_to_train = itertools.chain(text_encoder.parameters())
-    else:
-        logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
-        params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
-
     optimizer_config = None
     optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
     if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
@@ -514,6 +504,7 @@ def main(args):
 
     default_lr = 1e-6
     curr_lr = args.lr
+    text_encoder_lr_scale = 1.0
 
     if optimizer_config is not None:
         betas = optimizer_config["betas"]
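This hunk seeds `curr_lr` from the CLI and gives `text_encoder_lr_scale` its neutral default before the config file is consulted. Together with the override logic in the next hunk, the resulting precedence is: CLI/main-config LR first, then the optimizer.json value, then the hard-coded `default_lr`. A standalone sketch of that resolution order (the helper function is illustrative, not part of the diff):

def resolve_lr(cli_lr, config_lr, default_lr=1e-6):
    # CLI/main config wins over optimizer.json; fall back to the default.
    lr = cli_lr if cli_lr is not None else config_lr
    return lr if lr is not None else default_lr

assert resolve_lr(cli_lr=1e-5, config_lr=2e-6) == 1e-5  # CLI override wins
assert resolve_lr(cli_lr=None, config_lr=2e-6) == 2e-6  # optimizer.json value used
assert resolve_lr(cli_lr=None, config_lr=None) == 1e-6  # default fallback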
@@ -524,12 +515,30 @@ def main(args):
         if args.lr is not None:
             curr_lr = args.lr
             logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
+
+        text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale)
+        if text_encoder_lr_scale != 1.0:
+            print(f" * Using text encoder LR scale {text_encoder_lr_scale}")
+
         logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
 
     if curr_lr is None:
         curr_lr = default_lr
         logging.warning(f"No LR setting found, defaulting to {default_lr}")
 
+    text_encoder_lr = curr_lr * text_encoder_lr_scale
+
+    if args.disable_textenc_training:
+        logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
+        params_to_train = itertools.chain(unet.parameters())
+    elif args.disable_unet_training:
+        logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
+        params_to_train = itertools.chain(text_encoder.parameters())
+    else:
+        logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
+        params_to_train = [{'params': unet.parameters()},
+                           {'params': text_encoder.parameters(), 'lr': text_encoder_lr}]
+
     if optimizer_name:
         if optimizer_name == "lion":
             from lion_pytorch import Lion
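The list-of-dicts passed as `params_to_train` uses PyTorch's per-parameter-group options: each dict is one group, and a group-level 'lr' overrides the optimizer-wide default, which is how the text encoder gets its scaled rate while the unet keeps the base LR. A self-contained sketch with toy modules standing in for the real models:

import torch

unet = torch.nn.Linear(4, 4)          # stand-in for the unet
text_encoder = torch.nn.Linear(4, 4)  # stand-in for the text encoder

base_lr = 1e-6
text_encoder_lr = base_lr * 0.5  # scaled as in this commit (example scale)

params_to_train = [
    {'params': unet.parameters()},  # inherits the optimizer-wide lr
    {'params': text_encoder.parameters(), 'lr': text_encoder_lr},
]
optimizer = torch.optim.AdamW(params_to_train, lr=base_lr)

# One LR per group, in order; LR schedulers expose the same list via
# get_last_lr(), which is why train.py reads index [1] for the text encoder.
print([group['lr'] for group in optimizer.param_groups])  # [1e-06, 5e-07]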
@@ -803,7 +812,14 @@ def main(args):
                     loss_local = sum(loss_log_step) / len(loss_log_step)
                     loss_log_step = []
                     logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
-                    log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step = global_step)
+                    if args.disable_textenc_training or args.disable_unet_training:
+                        log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step = global_step)
+                    else:
+                        curr_text_encoder_lr = lr_scheduler.get_last_lr()[1]
+                        log_writer.add_scalars(main_tag="hyperparamater/lr", tag_scalar_dict={
+                            'unet': curr_lr,
+                            'text encoder': curr_text_encoder_lr
+                        }, global_step = global_step)
                     log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
                     sum_img = sum(images_per_sec_log_step)
                     avg = sum_img / len(images_per_sec_log_step)
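The logging hunk relies on TensorBoard's `add_scalars`, which writes several named series under one parent tag so both learning rates share a single chart ("hyperparamater" is kept verbatim from the existing tag spelling in train.py). A minimal standalone sketch (the log directory name is arbitrary):

from torch.utils.tensorboard import SummaryWriter

log_writer = SummaryWriter('logs/lr_demo')  # arbitrary demo directory
for global_step in range(3):
    # Two series plotted together under the same parent tag.
    log_writer.add_scalars(main_tag='hyperparamater/lr', tag_scalar_dict={
        'unet': 1e-6,
        'text encoder': 5e-7,
    }, global_step=global_step)
log_writer.close()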