Fix device mismatch when using min snr gamma option

This commit is contained in:
Gabriel Roldan 2023-11-13 12:17:56 -03:00
parent 6ea721887c
commit b212b94d13
No known key found for this signature in database
GPG Key ID: 6FAD6D4A395EB862
1 changed files with 1 additions and 1 deletions

View File

@ -974,7 +974,7 @@ def main(args):
/ snr / snr
) )
mse_loss_weights[snr == 0] = 1.0 mse_loss_weights[snr == 0] = 1.0
loss_scale = loss_scale * mse_loss_weights loss_scale = loss_scale * mse_loss_weights.to(loss_scale.device)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device) loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device)