Fix device mismatch when using min snr gamma option
This commit is contained in:
parent
6ea721887c
commit
b212b94d13
2
train.py
2
train.py
|
@ -974,7 +974,7 @@ def main(args):
|
|||
/ snr
|
||||
)
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
loss_scale = loss_scale * mse_loss_weights
|
||||
loss_scale = loss_scale * mse_loss_weights.to(loss_scale.device)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device)
|
||||
|
|
Loading…
Reference in New Issue