last tweak
parent 6eccbe0ecc
commit a26c127740

train.py (4 changed lines)
@@ -651,6 +651,9 @@ def main(args):
                     del timesteps, encoder_hidden_states, noisy_latents
                     #with autocast(enabled=args.amp):
                     loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                    del target, model_pred
+
                     if batch["runt_size"] > 0:
                         grad_scale = batch["runt_size"] / args.batch_size
                         with torch.no_grad(): # not required? just in case for now, needs more testing
@@ -661,7 +664,6 @@ def main(args):
                             for param in text_encoder.parameters():
                                 if param.grad is not None:
                                     param.grad *= grad_scale
-                    del target, model_pred
 
                     if args.clip_grad_norm is not None:
                         torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
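For context on the change above: the runt_size branch appears to down-weight the gradients of a final, undersized ("runt") batch in proportion to its size relative to args.batch_size, before the optional gradient clipping step. Below is a minimal, self-contained sketch of that pattern, with a toy model and made-up sizes standing in for the trainer's unet/text_encoder and args; none of these names come from train.py, only the scaling logic is mirrored.

import torch

# Illustrative stand-ins for the trainer's model and arguments (not from train.py).
model = torch.nn.Linear(4, 4)
batch_size = 8      # plays the role of args.batch_size
runt_size = 3       # plays the role of batch["runt_size"]: an undersized final batch

pred = model(torch.randn(runt_size, 4))
target = torch.randn(runt_size, 4)
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="mean")
loss.backward()

# Scale down the gradients of the runt batch, as in the diff above.
if runt_size > 0:
    grad_scale = runt_size / batch_size
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is not None:
                param.grad *= grad_scale

# Optional gradient clipping, mirroring the args.clip_grad_norm branch (max_norm is arbitrary here).
torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1.0)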