diff --git a/optimizer/adacoor.py b/optimizer/adacoor.py index 9c295e7..f72f95e 100644 --- a/optimizer/adacoor.py +++ b/optimizer/adacoor.py @@ -32,6 +32,6 @@ class AdaCoor(torch.optim.Optimizer): gt_hat = (epsilon * p.grad.data).to(dtype=torch.float32, device=p.device) denom = vt.sqrt().add_(group['epsilon']).to(dtype=p.dtype, device=p.device) - p.data.addcdiv_(gt_hat, denom, value=-1) - + p.data.addcdiv_(gt_hat, denom, value=-1.0) + return loss \ No newline at end of file