From 3bf95d4edc12ff1f48bc9cd72e6bcac7ef4fe6fc Mon Sep 17 00:00:00 2001
From: Victor Hall
Date: Sat, 24 Jun 2023 14:41:16 -0400
Subject: [PATCH] add dadapt_adan and fix bug in decay steps for te

---
 optimizer/optimizers.py | 33 +++++++++++++++++++++++++--------
 requirements.txt        |  2 +-
 2 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/optimizer/optimizers.py b/optimizer/optimizers.py
index 682fc56..543f0eb 100644
--- a/optimizer/optimizers.py
+++ b/optimizer/optimizers.py
@@ -204,8 +204,8 @@ class EveryDreamOptimizer():
         lr_scheduler = get_scheduler(
             te_config.get("lr_scheduler", args.lr_scheduler),
             optimizer=self.optimizer_te,
-            num_warmup_steps=te_config.get("lr_warmup_steps", None),
-            num_training_steps=unet_config.get("lr_decay_steps", None) or unet_config["lr_decay_steps"]
+            num_warmup_steps=int(te_config.get("lr_warmup_steps", None)) or unet_config["lr_warmup_steps"],
+            num_training_steps=int(te_config.get("lr_decay_steps", None)) or unet_config["lr_decay_steps"]
         )
         ret_val.append(lr_scheduler)

@@ -272,19 +272,23 @@ class EveryDreamOptimizer():

         default_lr = 1e-6
         curr_lr = args.lr
-        d0 = 1e-6 # dadapt
+        d0 = 1e-6 # dadapt
         decouple = True # seems bad to turn off, dadapt_adam only
-        momentum = 0.0 # dadapt_sgd only
+        momentum = 0.0 # dadapt_sgd
+        no_prox = False # ????, dadapt_adan
+        growth_rate=float("inf") # dadapt

         if local_optimizer_config is not None:
-            betas = local_optimizer_config["betas"] or betas
-            epsilon = local_optimizer_config["epsilon"] or epsilon
-            weight_decay = local_optimizer_config["weight_decay"] or weight_decay
-            optimizer_name = local_optimizer_config["optimizer"] or "adamw8bit"
+            betas = local_optimizer_config.get("betas", betas)
+            epsilon = local_optimizer_config.get("epsilon", epsilon)
+            weight_decay = local_optimizer_config.get("weight_decay", weight_decay)
+            no_prox = local_optimizer_config.get("no_prox", False)
+            optimizer_name = local_optimizer_config.get("optimizer", "adamw8bit")
             curr_lr = local_optimizer_config.get("lr", curr_lr)
             d0 = local_optimizer_config.get("d0", d0)
             decouple = local_optimizer_config.get("decouple", decouple)
             momentum = local_optimizer_config.get("momentum", momentum)
+            growth_rate = local_optimizer_config.get("growth_rate", growth_rate)
             if args.lr is not None:
                 curr_lr = args.lr
                 logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
@@ -335,6 +339,19 @@ class EveryDreamOptimizer():
                 growth_rate=1e5,
                 decouple=decouple,
             )
+        elif optimizer_name == "dadapt_adan":
+            opt_class = dadaptation.DAdaptAdan
+            optimizer = opt_class(
+                itertools.chain(parameters),
+                lr=curr_lr,
+                betas=(betas[0], betas[1]),
+                no_prox=no_prox,
+                weight_decay=weight_decay,
+                eps=epsilon,
+                d0=d0,
+                log_every=args.log_step,
+                growth_rate=growth_rate,
+            )
         elif optimizer_name == "dadapt_lion":
             opt_class = dadaptation.DAdaptLion
             optimizer = opt_class(
diff --git a/requirements.txt b/requirements.txt
index a8f72ec..31a7963 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 torch==2.0.1
 torchvision==0.15.2
 transformers==4.29.2
-diffusers[torch]==0.14.0
+diffusers[torch]==0.17.1
 pynvml==11.4.1
 bitsandbytes==0.38.1
 ftfy==6.1.1
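
For reference, the new "dadapt_adan" branch is picked up from the same local
optimizer config dict that the patched code reads with .get(), so a config entry
along the lines below should route to dadaptation.DAdaptAdan. This is a hedged
sketch only: the file name (optimizer.json) and the flat layout are assumptions,
the keys are the ones the diff reads, and the values are illustrative rather
than recommendations.

    {
        "optimizer": "dadapt_adan",
        "lr": 1e-6,
        "betas": [0.9, 0.999],
        "weight_decay": 0.02,
        "no_prox": false,
        "d0": 1e-6
    }

Keys that are left out fall back to the in-code defaults (d0=1e-6,
no_prox=False, growth_rate=float("inf")), since every lookup in the patched
block now goes through .get() with a default instead of direct indexing.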