diff --git a/train.py b/train.py index 1949fc8..8beece9 100644 --- a/train.py +++ b/train.py @@ -216,7 +216,7 @@ def setup_args(args): """ if args.disable_amp: logging.warning(f"{Fore.LIGHTYELLOW_EX} Disabling AMP, not recommended.{Style.RESET_ALL}") - args.amp= False + args.amp = False else: args.amp = True @@ -811,6 +811,12 @@ def main(args): model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio) + # with torch.no_grad(): + # loss_l1 = F.l1_loss(model_pred.float(), target.float(), reduction="mean") + # log_writer.add_scalar(tag="loss/l1", scalar_value=loss_l1, global_step=global_step) + # loss_hinge = F.hinge_embedding_loss(model_pred.float(), target.float(), reduction="mean") + # log_writer.add_scalar(tag="loss/hinge", scalar_value=loss_hinge, global_step=global_step) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") del target, model_pred