cruft left from experiment
This commit is contained in:
parent
605716a646
commit
d1bc94fe3e
8
train.py
8
train.py
|
@ -216,7 +216,7 @@ def setup_args(args):
|
||||||
"""
|
"""
|
||||||
if args.disable_amp:
|
if args.disable_amp:
|
||||||
logging.warning(f"{Fore.LIGHTYELLOW_EX} Disabling AMP, not recommended.{Style.RESET_ALL}")
|
logging.warning(f"{Fore.LIGHTYELLOW_EX} Disabling AMP, not recommended.{Style.RESET_ALL}")
|
||||||
args.amp= False
|
args.amp = False
|
||||||
else:
|
else:
|
||||||
args.amp = True
|
args.amp = True
|
||||||
|
|
||||||
|
@ -811,6 +811,12 @@ def main(args):
|
||||||
|
|
||||||
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
|
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
|
||||||
|
|
||||||
|
# with torch.no_grad():
|
||||||
|
# loss_l1 = F.l1_loss(model_pred.float(), target.float(), reduction="mean")
|
||||||
|
# log_writer.add_scalar(tag="loss/l1", scalar_value=loss_l1, global_step=global_step)
|
||||||
|
# loss_hinge = F.hinge_embedding_loss(model_pred.float(), target.float(), reduction="mean")
|
||||||
|
# log_writer.add_scalar(tag="loss/hinge", scalar_value=loss_hinge, global_step=global_step)
|
||||||
|
|
||||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||||
|
|
||||||
del target, model_pred
|
del target, model_pred
|
||||||
|
|
Loading…
Reference in New Issue