fix device mismatch with loss_scale

This commit is contained in:
Damian Stewart 2023-11-01 09:29:41 +01:00 committed by GitHub
parent a7343ad190
commit c485d4ea60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -963,7 +963,7 @@ def main(args):
loss_scale = loss_scale * mse_loss_weights
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device)
loss = loss.mean()
return model_pred, target, loss