import torch


class AdaCoor(torch.optim.Optimizer):
    """Coordinate-wise adaptive optimizer: ``vt`` accumulates
    ``eps * grad**2`` in bfloat16, and each step applies
    ``p <- p - (eps * grad) / (sqrt(vt) + eps)``."""

    def __init__(self, params, eps=1e-8, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        # Note: lr is stored in defaults but never read in step().
        defaults = dict(epsilon=eps, lr=1)
        super().__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            with torch.no_grad():
                # Epsilon as a bfloat16 tensor on the parameters' device.
                epsilon = torch.tensor(
                    [group['epsilon']],
                    dtype=torch.bfloat16,
                    device=next(iter(group['params'])).device,
                )

                for p in group['params']:
                    if p.grad is None:
                        continue

                    state = self.state[p]

                    # Lazily initialize the second-moment accumulator vt.
                    if 'vt' not in state:
                        state['vt'] = torch.zeros_like(p, dtype=torch.bfloat16)

                    vt = state['vt']
                    # Accumulate the eps-scaled squared gradient in bfloat16.
                    vt.add_((epsilon * p.grad ** 2).to(torch.bfloat16))

                    # Eps-scaled gradient, promoted to float32 for the division.
                    gt_hat = (epsilon * p.grad).to(torch.float32)

                    # sqrt(vt) + eps, cast to the parameter's dtype.
                    denom = vt.sqrt().add_(group['epsilon']).to(p.dtype)

                    # p <- p - gt_hat / denom; gt_hat is cast to p.dtype so
                    # the in-place op also works for non-float32 parameters.
                    p.addcdiv_(gt_hat.to(p.dtype), denom, value=-1.0)

        return loss
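

# A minimal usage sketch (an assumption, not part of the original file):
# the model, data, and loss below are hypothetical placeholders that only
# show how AdaCoor plugs into a standard PyTorch training loop.
if __name__ == "__main__":
    model = torch.nn.Linear(4, 1)   # hypothetical tiny model
    optimizer = AdaCoor(model.parameters(), eps=1e-8)

    x = torch.randn(8, 4)           # hypothetical input batch
    y = torch.randn(8, 1)           # hypothetical targets

    for _ in range(10):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()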