diff --git a/train.py b/train.py
index a2ece9a..750ff11 100644
--- a/train.py
+++ b/train.py
@@ -651,6 +651,9 @@ def main(args):
             del timesteps, encoder_hidden_states, noisy_latents
             #with autocast(enabled=args.amp):
             loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+            del target, model_pred
+
             if batch["runt_size"] > 0:
                 grad_scale = batch["runt_size"] / args.batch_size
                 with torch.no_grad(): # not required? just in case for now, needs more testing
@@ -661,7 +664,6 @@ def main(args):
                     for param in text_encoder.parameters():
                         if param.grad is not None:
                             param.grad *= grad_scale
-            del target, model_pred
 
             if args.clip_grad_norm is not None:
                 torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
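The pattern being applied here, dropping Python references to large intermediate tensors as soon as the loss has been computed, can be illustrated outside the project. The following is a minimal sketch of that idea, not code from train.py; the model, optimizer, and tensor names are placeholders chosen for the example.

```python
import torch
import torch.nn.functional as F

# Placeholder model and optimizer, purely for illustration.
model = torch.nn.Linear(4096, 4096)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train_step(batch: torch.Tensor, target: torch.Tensor) -> float:
    model_pred = model(batch)
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # Once the loss exists, these names are no longer needed. Deleting them
    # frees only memory that the autograd graph does not retain for backward,
    # but it lets the allocator reuse that memory earlier in the step.
    del model_pred, target

    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return loss.item()

if __name__ == "__main__":
    x = torch.randn(8, 4096)
    y = torch.randn(8, 4096)
    print(train_step(x, y))
```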