make sure main lr arg overrides optimizer.json
This commit is contained in:
parent
b5fad8c675
commit
600eaa404d
4
train.py
4
train.py
|
@ -521,6 +521,8 @@ def main(args):
|
|||
weight_decay = optimizer_config["weight_decay"]
|
||||
optimizer_name = optimizer_config["optimizer"]
|
||||
curr_lr = optimizer_config.get("lr", curr_lr)
|
||||
if args.lr is not None:
|
||||
curr_lr = args.lr
|
||||
logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
|
||||
|
||||
if curr_lr is None:
|
||||
|
@ -760,8 +762,6 @@ def main(args):
|
|||
|
||||
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
|
||||
|
||||
#del timesteps, encoder_hidden_states, noisy_latents
|
||||
#with autocast(enabled=args.amp):
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
del target, model_pred
|
||||
|
|
Loading…
Reference in New Issue