fix device mismatch with loss_scale
This commit is contained in:
parent
a7343ad190
commit
c485d4ea60
2
train.py
2
train.py
|
@ -963,7 +963,7 @@ def main(args):
|
|||
loss_scale = loss_scale * mse_loss_weights
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device)
|
||||
loss = loss.mean()
|
||||
|
||||
return model_pred, target, loss
|
||||
|
|
Loading…
Reference in New Issue