# EveryDream2trainer/optimizer/adamcm.py
#
# 94 lines
# 4.2 KiB
# Python

import torch
from torch.optim.optimizer import Optimizer
class AdamCM(Optimizer):
def __init__(self, params, lr=1e-6, betas=(0.9, 0.999), epsilon=1e-8,
weight_decay=0, buffer_capacity=10, decay_lambda=0.8):
if lr < 0.0 or lr is None:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= epsilon:
raise ValueError(f"Invalid epsilon value: {epsilon}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
defaults = dict(lr=lr, betas=betas, epsilon=epsilon,
weight_decay=weight_decay, buffer_capacity=buffer_capacity,
decay_lambda=decay_lambda)
super(AdamCM, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdamCM, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
with torch.cuda.amp.autocast():
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('AdamCM does not support sparse gradients')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
state['buffer'] = []
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.bfloat16)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.bfloat16)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# Update buffer
priority = grad.norm()
if len(state['buffer']) < group['buffer_capacity']:
state['buffer'].append((priority, exp_avg.bfloat16().clone(), exp_avg_sq.bfloat16().clone()))
else:
# Find and replace the gradient with the smallest priority
min_priority, min_idx = min((buf[0], idx) for idx, buf in enumerate(state['buffer']))
if priority > min_priority:
state['buffer'][min_idx] = (priority, exp_avg.bfloat16().clone(), exp_avg_sq.bfloat16().clone())
# Decay priorities
#for i, buf in enumerate(state['buffer']):
# buf[0] *= group['decay_lambda']
# Aggregate momenta
critical_exp_avg = torch.zeros_like(exp_avg, dtype=torch.bfloat16)
critical_exp_avg_sq = torch.zeros_like(exp_avg_sq, dtype=torch.bfloat16)
for i, (priority, buf_exp_avg, buf_exp_avg_sq) in enumerate(state['buffer']):
decayed_priority = priority * group['decay_lambda']
critical_exp_avg.add_(buf_exp_avg)
critical_exp_avg_sq.add_(buf_exp_avg_sq)
state['buffer'][i] = (decayed_priority, buf_exp_avg, buf_exp_avg_sq)
denom = critical_exp_avg_sq.sqrt().add_(group['epsilon'])
step_size = group['lr']
if group['weight_decay'] != 0:
grad.add_(p.data, alpha=group['weight_decay'])
# Parameter update
p.data.addcdiv_(critical_exp_avg, denom, value=-step_size)
return loss