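"""AdamCM: an Adam variant that aggregates a buffer of critical momenta."""
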
import torch
from torch.optim.optimizer import Optimizer


class AdamCM(Optimizer):
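    """Adam variant that maintains a bounded buffer of critical momenta.

    Each step snapshots the current moment estimates into a fixed-capacity
    buffer keyed by gradient norm. The buffered momenta are summed to form
    the parameter update, and every entry's priority decays by
    ``decay_lambda`` so stale snapshots are eventually evicted.
    """
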
    def __init__(self, params, lr=1e-6, betas=(0.9, 0.999), epsilon=1e-8,
                 weight_decay=0, buffer_capacity=10, decay_lambda=0.8):
        # Validate hyperparameters before registering them as defaults
        if lr is None or lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= epsilon:
            raise ValueError(f"Invalid epsilon value: {epsilon}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

        defaults = dict(lr=lr, betas=betas, epsilon=epsilon,
                        weight_decay=weight_decay, buffer_capacity=buffer_capacity,
                        decay_lambda=decay_lambda)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): a closure that re-evaluates the
                model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('AdamCM does not support sparse gradients')

                # Classic (coupled) L2 regularization: fold weight decay into
                # the gradient before it feeds the moment estimates
                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Buffer of (priority, exp_avg, exp_avg_sq) snapshots
                    state['buffer'] = []
                    # Exponential moving average of gradient values,
                    # stored in bfloat16 to reduce optimizer-state memory
                    state['exp_avg'] = torch.zeros_like(p, dtype=torch.bfloat16)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.bfloat16)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running averages
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Insert the current momenta into the critical-momenta buffer,
                # using the gradient norm as the priority key
                priority = grad.norm().item()
                if len(state['buffer']) < group['buffer_capacity']:
                    state['buffer'].append((priority, exp_avg.clone(), exp_avg_sq.clone()))
                else:
                    # Evict the entry with the smallest priority, but only if
                    # the incoming snapshot outranks it
                    min_priority, min_idx = min(
                        (buf[0], idx) for idx, buf in enumerate(state['buffer']))
                    if priority > min_priority:
                        state['buffer'][min_idx] = (priority, exp_avg.clone(), exp_avg_sq.clone())

                # Aggregate the buffered momenta; decay each entry's priority
                # by decay_lambda so stale snapshots are eventually evicted
                critical_exp_avg = torch.zeros_like(exp_avg, dtype=torch.bfloat16)
                critical_exp_avg_sq = torch.zeros_like(exp_avg_sq, dtype=torch.bfloat16)

                for i, (buf_priority, buf_exp_avg, buf_exp_avg_sq) in enumerate(state['buffer']):
                    critical_exp_avg.add_(buf_exp_avg)
                    critical_exp_avg_sq.add_(buf_exp_avg_sq)
                    state['buffer'][i] = (buf_priority * group['decay_lambda'],
                                          buf_exp_avg, buf_exp_avg_sq)

                denom = critical_exp_avg_sq.sqrt().add_(group['epsilon'])
                step_size = group['lr']

                # Parameter update from the aggregated critical momenta
                p.addcdiv_(critical_exp_avg, denom, value=-step_size)

        return loss
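

if __name__ == "__main__":
    # Minimal usage sketch on a hypothetical toy regression problem; the
    # model, data, and hyperparameter choices here are illustrative only.
    torch.manual_seed(0)
    model = torch.nn.Linear(10, 1)
    optimizer = AdamCM(model.parameters(), lr=1e-3,
                       buffer_capacity=10, decay_lambda=0.8)
    x, y = torch.randn(32, 10), torch.randn(32, 1)
    for _ in range(5):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item():.4f}")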