diff --git a/optimizer/adamcm.py b/optimizer/adamcm.py
new file mode 100644
index 0000000..d5f0a57
--- /dev/null
+++ b/optimizer/adamcm.py
@@ -0,0 +1,93 @@
+import torch
+from torch.optim.optimizer import Optimizer
+
+class AdamCM(Optimizer):
+    """Adam variant that keeps a small buffer of 'critical' momenta, captured at
+    steps with large gradient norms, and aggregates the buffered momenta to form
+    the parameter update. Buffer priorities decay by `decay_lambda` each step."""
+    def __init__(self, params, lr=1e-6, betas=(0.9, 0.999), epsilon=1e-8,
+                 weight_decay=0, buffer_capacity=10, decay_lambda=0.8):
+        if lr is None or lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if not 0.0 <= epsilon:
+            raise ValueError(f"Invalid epsilon value: {epsilon}")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+
+        defaults = dict(lr=lr, betas=betas, epsilon=epsilon,
+                        weight_decay=weight_decay, buffer_capacity=buffer_capacity,
+                        decay_lambda=decay_lambda)
+        super(AdamCM, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(AdamCM, self).__setstate__(state)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        with torch.cuda.amp.autocast():
+            for group in self.param_groups:
+                for p in group['params']:
+                    if p.grad is None:
+                        continue
+                    grad = p.grad.data
+                    if grad.is_sparse:
+                        raise RuntimeError('AdamCM does not support sparse gradients')
+
+                    state = self.state[p]
+
+                    # State initialization
+                    if len(state) == 0:
+                        state['step'] = 0
+                        state['buffer'] = []
+                        # Exponential moving average of gradient values
+                        state['exp_avg'] = torch.zeros_like(p.data, dtype=torch.bfloat16)
+                        # Exponential moving average of squared gradient values
+                        state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=torch.bfloat16)
+
+                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                    beta1, beta2 = group['betas']
+
+                    state['step'] += 1
+
+                    # Apply L2 regularization before the momenta are updated so that
+                    # weight decay actually influences the parameter update
+                    if group['weight_decay'] != 0:
+                        grad = grad.add(p.data, alpha=group['weight_decay'])
+
+                    # Decay the first and second moment running average coefficient
+                    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+                    # Update the critical-momenta buffer, keyed by gradient norm
+                    priority = grad.norm().item()
+                    if len(state['buffer']) < group['buffer_capacity']:
+                        state['buffer'].append((priority, exp_avg.clone(), exp_avg_sq.clone()))
+                    else:
+                        # Replace the lowest-priority entry if the new gradient norm is larger
+                        min_priority, min_idx = min((buf[0], idx) for idx, buf in enumerate(state['buffer']))
+                        if priority > min_priority:
+                            state['buffer'][min_idx] = (priority, exp_avg.clone(), exp_avg_sq.clone())
+
+                    # Aggregate the buffered momenta and decay their priorities
+                    critical_exp_avg = torch.zeros_like(exp_avg, dtype=torch.bfloat16)
+                    critical_exp_avg_sq = torch.zeros_like(exp_avg_sq, dtype=torch.bfloat16)
+
+                    for i, (buf_priority, buf_exp_avg, buf_exp_avg_sq) in enumerate(state['buffer']):
+                        critical_exp_avg.add_(buf_exp_avg)
+                        critical_exp_avg_sq.add_(buf_exp_avg_sq)
+                        state['buffer'][i] = (buf_priority * group['decay_lambda'], buf_exp_avg, buf_exp_avg_sq)
+
+                    denom = critical_exp_avg_sq.sqrt().add_(group['epsilon'])
+                    step_size = group['lr']
+
+                    # Parameter update
+                    p.data.addcdiv_(critical_exp_avg, denom, value=-step_size)
+
+        return loss
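
As a quick sanity check for reviewers, the new optimizer can be exercised on a toy problem. This is only an illustrative sketch, not part of the PR: the model, data, and hyperparameter values below are made up, and the import assumes the repository root is on PYTHONPATH.

import torch
import torch.nn.functional as F
from optimizer.adamcm import AdamCM

model = torch.nn.Linear(16, 1)
# Constructor arguments mirror the defaults defined in adamcm.py above
opt = AdamCM(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
             weight_decay=0, buffer_capacity=10, decay_lambda=0.8)

x, y = torch.randn(64, 16), torch.randn(64, 1)
for _ in range(10):
    opt.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    opt.step()

A few steps with a decreasing loss is a reasonable smoke test for the buffer bookkeeping and the bfloat16 state tensors.
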
diff --git a/optimizer/optimizers.py b/optimizer/optimizers.py
index 68ef581..bf2a398 100644
--- a/optimizer/optimizers.py
+++ b/optimizer/optimizers.py
@@ -232,14 +232,14 @@ class EveryDreamOptimizer():
         base_config["lr_decay_steps"] = base_config.get("lr_decay_steps", None) or args.lr_decay_steps
         base_config["lr_scheduler"] = base_config.get("lr_scheduler", None) or args.lr_scheduler
-        base_config["lr_warmup_steps"] = base_config.get("lr_warmup_steps", args.lr_warmup_steps)
+        base_config["lr_warmup_steps"] = base_config.get("lr_warmup_steps", None) or args.lr_warmup_steps
 
         base_config["lr_decay_steps"] = base_config.get("lr_decay_steps", None) or args.lr_decay_steps
         base_config["lr_scheduler"] = base_config.get("lr_scheduler", None) or args.lr_scheduler
 
         te_config["lr"] = te_config.get("lr", None) or base_config["lr"]
         te_config["optimizer"] = te_config.get("optimizer", None) or base_config["optimizer"]
         te_config["lr_scheduler"] = te_config.get("lr_scheduler", None) or base_config["lr_scheduler"]
-        te_config["lr_warmup_steps"] = te_config.get("lr_warmup_steps", base_config["lr_warmup_steps"])
+        te_config["lr_warmup_steps"] = te_config.get("lr_warmup_steps", None) or base_config["lr_warmup_steps"]
         te_config["lr_decay_steps"] = te_config.get("lr_decay_steps", None) or base_config["lr_decay_steps"]
         te_config["weight_decay"] = te_config.get("weight_decay", None) or base_config["weight_decay"]
         te_config["betas"] = te_config.get("betas", None) or base_config["betas"]
@@ -257,8 +257,8 @@ class EveryDreamOptimizer():
             lr_scheduler = get_scheduler(
                 te_config.get("lr_scheduler", args.lr_scheduler),
                 optimizer=self.optimizer_te,
-                num_warmup_steps=int(te_config.get("lr_warmup_steps", None) or unet_config["lr_warmup_steps"]),
-                num_training_steps=int(te_config.get("lr_decay_steps", None) or unet_config["lr_decay_steps"])
+                num_warmup_steps=int(te_config.get("lr_warmup_steps", None) or unet_config.get("lr_warmup_steps", 0)),
+                num_training_steps=int(te_config.get("lr_decay_steps", None) or unet_config.get("lr_decay_steps", 1e9))
             )
             ret_val.append(lr_scheduler)
 
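
The optimizers.py changes all follow the same pattern: `.get(key, default)` only falls back when the key is absent, while `.get(key, None) or fallback` also falls back when the key is present but set to null in the user's JSON config. A minimal illustration (the values are made up, not taken from this PR):

te_config = {"lr_warmup_steps": None}   # key present but explicitly null
args_lr_warmup_steps = 200

te_config.get("lr_warmup_steps", args_lr_warmup_steps)          # -> None, default is ignored
te_config.get("lr_warmup_steps", None) or args_lr_warmup_steps  # -> 200

One trade-off worth noting in review: because `or` treats falsy values as missing, an explicit 0 (e.g. zero warmup steps) would also fall through to the fallback.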