last tweak

This commit is contained in:
Victor Hall 2023-01-03 15:17:24 -05:00
parent 6eccbe0ecc
commit a26c127740
1 changed files with 3 additions and 1 deletions

View File

@ -651,6 +651,9 @@ def main(args):
del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
del target, model_pred
if batch["runt_size"] > 0:
grad_scale = batch["runt_size"] / args.batch_size
with torch.no_grad(): # not required? just in case for now, needs more testing
@ -661,7 +664,6 @@ def main(args):
for param in text_encoder.parameters():
if param.grad is not None:
param.grad *= grad_scale
del target, model_pred
if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)