gg bug from pr that thought it was simplifying code but broke it

This commit is contained in:
Victor Hall 2023-11-15 16:19:08 -05:00
parent 097d864ef5
commit 840493037e
2 changed files with 97 additions and 4 deletions

93
optimizer/adamcm.py Normal file
View File

@ -0,0 +1,93 @@
import torch
from torch.optim.optimizer import Optimizer
class AdamCM(Optimizer):
def __init__(self, params, lr=1e-6, betas=(0.9, 0.999), epsilon=1e-8,
weight_decay=0, buffer_capacity=10, decay_lambda=0.8):
if lr < 0.0 or lr is None:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= epsilon:
raise ValueError(f"Invalid epsilon value: {epsilon}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
defaults = dict(lr=lr, betas=betas, epsilon=epsilon,
weight_decay=weight_decay, buffer_capacity=buffer_capacity,
decay_lambda=decay_lambda)
super(AdamCM, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdamCM, self).__setstate__(state)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
with torch.cuda.amp.autocast():
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('AdamCM does not support sparse gradients')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
state['buffer'] = []
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.bfloat16)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.bfloat16)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# Update buffer
priority = grad.norm()
if len(state['buffer']) < group['buffer_capacity']:
state['buffer'].append((priority, exp_avg.bfloat16().clone(), exp_avg_sq.bfloat16().clone()))
else:
# Find and replace the gradient with the smallest priority
min_priority, min_idx = min((buf[0], idx) for idx, buf in enumerate(state['buffer']))
if priority > min_priority:
state['buffer'][min_idx] = (priority, exp_avg.bfloat16().clone(), exp_avg_sq.bfloat16().clone())
# Decay priorities
#for i, buf in enumerate(state['buffer']):
# buf[0] *= group['decay_lambda']
# Aggregate momenta
critical_exp_avg = torch.zeros_like(exp_avg, dtype=torch.bfloat16)
critical_exp_avg_sq = torch.zeros_like(exp_avg_sq, dtype=torch.bfloat16)
for i, (priority, buf_exp_avg, buf_exp_avg_sq) in enumerate(state['buffer']):
decayed_priority = priority * group['decay_lambda']
critical_exp_avg.add_(buf_exp_avg)
critical_exp_avg_sq.add_(buf_exp_avg_sq)
state['buffer'][i] = (decayed_priority, buf_exp_avg, buf_exp_avg_sq)
denom = critical_exp_avg_sq.sqrt().add_(group['epsilon'])
step_size = group['lr']
if group['weight_decay'] != 0:
grad.add_(p.data, alpha=group['weight_decay'])
# Parameter update
p.data.addcdiv_(critical_exp_avg, denom, value=-step_size)
return loss

View File

@ -232,14 +232,14 @@ class EveryDreamOptimizer():
base_config["lr_decay_steps"] = base_config.get("lr_decay_steps", None) or args.lr_decay_steps
base_config["lr_scheduler"] = base_config.get("lr_scheduler", None) or args.lr_scheduler
base_config["lr_warmup_steps"] = base_config.get("lr_warmup_steps", args.lr_warmup_steps)
base_config["lr_warmup_steps"] = base_config.get("lr_warmup_steps", None) or args.lr_warmup_steps
base_config["lr_decay_steps"] = base_config.get("lr_decay_steps", None) or args.lr_decay_steps
base_config["lr_scheduler"] = base_config.get("lr_scheduler", None) or args.lr_scheduler
te_config["lr"] = te_config.get("lr", None) or base_config["lr"]
te_config["optimizer"] = te_config.get("optimizer", None) or base_config["optimizer"]
te_config["lr_scheduler"] = te_config.get("lr_scheduler", None) or base_config["lr_scheduler"]
te_config["lr_warmup_steps"] = te_config.get("lr_warmup_steps", base_config["lr_warmup_steps"])
te_config["lr_warmup_steps"] = te_config.get("lr_warmup_steps", None) or base_config["lr_warmup_steps"]
te_config["lr_decay_steps"] = te_config.get("lr_decay_steps", None) or base_config["lr_decay_steps"]
te_config["weight_decay"] = te_config.get("weight_decay", None) or base_config["weight_decay"]
te_config["betas"] = te_config.get("betas", None) or base_config["betas"]
@ -257,8 +257,8 @@ class EveryDreamOptimizer():
lr_scheduler = get_scheduler(
te_config.get("lr_scheduler", args.lr_scheduler),
optimizer=self.optimizer_te,
num_warmup_steps=int(te_config.get("lr_warmup_steps", None) or unet_config["lr_warmup_steps"]),
num_training_steps=int(te_config.get("lr_decay_steps", None) or unet_config["lr_decay_steps"])
num_warmup_steps=int(te_config.get("lr_warmup_steps", None) or unet_config.get("lr_warmup_steps",0)),
num_training_steps=int(te_config.get("lr_decay_steps", None) or unet_config.get("lr_decay_steps",1e9))
)
ret_val.append(lr_scheduler)