add dadapt_adan and fix bug in decay steps for te

Victor Hall 2023-06-24 14:41:16 -04:00
parent 4fd4e38bbd
commit 3bf95d4edc
2 changed files with 26 additions and 9 deletions

View File

@@ -204,8 +204,8 @@ class EveryDreamOptimizer():
lr_scheduler = get_scheduler(
te_config.get("lr_scheduler", args.lr_scheduler),
optimizer=self.optimizer_te,
- num_warmup_steps=te_config.get("lr_warmup_steps", None),
- num_training_steps=unet_config.get("lr_decay_steps", None) or unet_config["lr_decay_steps"]
+ num_warmup_steps=int(te_config.get("lr_warmup_steps", None)) or unet_config["lr_warmup_steps"],
+ num_training_steps=int(te_config.get("lr_decay_steps", None)) or unet_config["lr_decay_steps"]
)
ret_val.append(lr_scheduler)
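
Note on the hunk above: the removed line took the text-encoder scheduler's decay steps from unet_config on both sides of the `or`, so a text-encoder-specific "lr_decay_steps" was never honored; the replacement lines read te_config first and fall back to the unet values. A minimal standalone sketch of that `x or y` fallback (toy dicts, not repo code; 0 is used as the .get default here so the int() call stays well-defined):

    # Toy configs: the text encoder sets no scheduler steps of its own.
    te_config = {"lr_warmup_steps": 0, "lr_decay_steps": 0}
    unet_config = {"lr_warmup_steps": 100, "lr_decay_steps": 2000}

    # Same pattern as the added lines above: a missing/zero TE value falls through to the unet value.
    warmup = int(te_config.get("lr_warmup_steps", 0)) or unet_config["lr_warmup_steps"]
    decay = int(te_config.get("lr_decay_steps", 0)) or unet_config["lr_decay_steps"]
    print(warmup, decay)  # -> 100 2000
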
@@ -272,19 +272,23 @@ class EveryDreamOptimizer():
default_lr = 1e-6
curr_lr = args.lr
- d0 = 1e-6 # dadapt
+ d0 = 1e-6 # dadapt
decouple = True # seems bad to turn off, dadapt_adam only
- momentum = 0.0 # dadapt_sgd only
+ momentum = 0.0 # dadapt_sgd
+ no_prox = False # ????, dadapt_adan
+ growth_rate=float("inf") # dadapt
if local_optimizer_config is not None:
- betas = local_optimizer_config["betas"] or betas
- epsilon = local_optimizer_config["epsilon"] or epsilon
- weight_decay = local_optimizer_config["weight_decay"] or weight_decay
- optimizer_name = local_optimizer_config["optimizer"] or "adamw8bit"
+ betas = local_optimizer_config.get("betas", betas)
+ epsilon = local_optimizer_config.get("epsilon", epsilon)
+ weight_decay = local_optimizer_config.get("weight_decay", weight_decay)
+ no_prox = local_optimizer_config.get("no_prox", False)
+ optimizer_name = local_optimizer_config.get("optimizer", "adamw8bit")
curr_lr = local_optimizer_config.get("lr", curr_lr)
d0 = local_optimizer_config.get("d0", d0)
decouple = local_optimizer_config.get("decouple", decouple)
momentum = local_optimizer_config.get("momentum", momentum)
+ growth_rate = local_optimizer_config.get("growth_rate", growth_rate)
if args.lr is not None:
curr_lr = args.lr
logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
@@ -335,6 +339,19 @@ class EveryDreamOptimizer():
growth_rate=1e5,
decouple=decouple,
)
+ elif optimizer_name == "dadapt_adan":
+ opt_class = dadaptation.DAdaptAdan
+ optimizer = opt_class(
+ itertools.chain(parameters),
+ lr=curr_lr,
+ betas=(betas[0], betas[1]),
+ no_prox=no_prox,
+ weight_decay=weight_decay,
+ eps=epsilon,
+ d0=d0,
+ log_every=args.log_step,
+ growth_rate=growth_rate,
+ )
elif optimizer_name == "dadapt_lion":
opt_class = dadaptation.DAdaptLion
optimizer = opt_class(
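
Taken together, the two hunks above add a "dadapt_adan" choice that is wired up through the same local_optimizer_config reads as the other optimizers. A rough end-to-end sketch, assuming the dadaptation package referenced above exposes DAdaptAdan with the keyword arguments used in the new branch (the config dict, toy model, and training step below are illustrative, not the project's actual optimizer config format):

    import dadaptation
    import torch

    # Hypothetical config using the keys read via local_optimizer_config.get(...) above.
    optimizer_config = {
        "optimizer": "dadapt_adan",
        "lr": 1.0,                      # D-Adaptation optimizers are usually run with lr near 1.0
        "weight_decay": 0.02,
        "epsilon": 1e-8,
        "no_prox": False,
        "d0": 1e-6,
        "growth_rate": float("inf"),
    }

    model = torch.nn.Linear(8, 8)       # stand-in for the unet / text-encoder parameters

    if optimizer_config.get("optimizer") == "dadapt_adan":
        optimizer = dadaptation.DAdaptAdan(
            model.parameters(),
            lr=optimizer_config.get("lr", 1.0),
            weight_decay=optimizer_config.get("weight_decay", 0.02),
            eps=optimizer_config.get("epsilon", 1e-8),
            no_prox=optimizer_config.get("no_prox", False),
            d0=optimizer_config.get("d0", 1e-6),
            growth_rate=optimizer_config.get("growth_rate", float("inf")),
        )

    # One illustrative optimization step.
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
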

View File

@@ -1,7 +1,7 @@
torch==2.0.1
torchvision==0.15.2
transformers==4.29.2
- diffusers[torch]==0.14.0
+ diffusers[torch]==0.17.1
pynvml==11.4.1
bitsandbytes==0.38.1
ftfy==6.1.1