Fix device mismatch when using min snr gamma option
This commit is contained in:
parent
6ea721887c
commit
b212b94d13
2
train.py
2
train.py
|
@ -974,7 +974,7 @@ def main(args):
|
||||||
/ snr
|
/ snr
|
||||||
)
|
)
|
||||||
mse_loss_weights[snr == 0] = 1.0
|
mse_loss_weights[snr == 0] = 1.0
|
||||||
loss_scale = loss_scale * mse_loss_weights
|
loss_scale = loss_scale * mse_loss_weights.to(loss_scale.device)
|
||||||
|
|
||||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device)
|
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device)
|
||||||
|
|
Loading…
Reference in New Issue