last tweak
This commit is contained in:
parent
6eccbe0ecc
commit
a26c127740
4
train.py
4
train.py
|
@ -651,6 +651,9 @@ def main(args):
|
||||||
del timesteps, encoder_hidden_states, noisy_latents
|
del timesteps, encoder_hidden_states, noisy_latents
|
||||||
#with autocast(enabled=args.amp):
|
#with autocast(enabled=args.amp):
|
||||||
loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||||
|
|
||||||
|
del target, model_pred
|
||||||
|
|
||||||
if batch["runt_size"] > 0:
|
if batch["runt_size"] > 0:
|
||||||
grad_scale = batch["runt_size"] / args.batch_size
|
grad_scale = batch["runt_size"] / args.batch_size
|
||||||
with torch.no_grad(): # not required? just in case for now, needs more testing
|
with torch.no_grad(): # not required? just in case for now, needs more testing
|
||||||
|
@ -661,7 +664,6 @@ def main(args):
|
||||||
for param in text_encoder.parameters():
|
for param in text_encoder.parameters():
|
||||||
if param.grad is not None:
|
if param.grad is not None:
|
||||||
param.grad *= grad_scale
|
param.grad *= grad_scale
|
||||||
del target, model_pred
|
|
||||||
|
|
||||||
if args.clip_grad_norm is not None:
|
if args.clip_grad_norm is not None:
|
||||||
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
|
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
|
||||||
|
|
Loading…
Reference in New Issue