From b212b94d13814501c902060b79dd75988551afb5 Mon Sep 17 00:00:00 2001 From: Gabriel Roldan Date: Mon, 13 Nov 2023 12:17:56 -0300 Subject: [PATCH] Fix device mismatch when using min snr gamma option --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 52c2ba6..998bac1 100644 --- a/train.py +++ b/train.py @@ -974,7 +974,7 @@ def main(args): / snr ) mse_loss_weights[snr == 0] = 1.0 - loss_scale = loss_scale * mse_loss_weights + loss_scale = loss_scale * mse_loss_weights.to(loss_scale.device) loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device)