add dadapt_adan and fix bug in decay steps for te
parent 4fd4e38bbd
commit 3bf95d4edc
@@ -204,8 +204,8 @@ class EveryDreamOptimizer():
         lr_scheduler = get_scheduler(
             te_config.get("lr_scheduler", args.lr_scheduler),
             optimizer=self.optimizer_te,
-            num_warmup_steps=te_config.get("lr_warmup_steps", None),
-            num_training_steps=unet_config.get("lr_decay_steps", None) or unet_config["lr_decay_steps"]
+            num_warmup_steps=int(te_config.get("lr_warmup_steps", None)) or unet_config["lr_warmup_steps"],
+            num_training_steps=int(te_config.get("lr_decay_steps", None)) or unet_config["lr_decay_steps"]
         )
         ret_val.append(lr_scheduler)

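The first hunk is the decay-steps fix named in the commit title: previously num_training_steps for the text-encoder scheduler was read only from unet_config, so a separate lr_decay_steps in te_config was ignored. The new lines read the te values first and fall back to the unet values when they are unset. Below is a minimal sketch of that fallback pattern with hypothetical config dicts; it uses 0 rather than None as the .get default (int(None) would raise a TypeError), so it illustrates the intent of the change, not the project's actual loader.

    # Hypothetical per-model configs; only the fallback pattern matters here.
    te_config = {"lr_warmup_steps": 100}  # no decay steps set for the text encoder
    unet_config = {"lr_warmup_steps": 200, "lr_decay_steps": 5000}

    warmup = int(te_config.get("lr_warmup_steps", 0)) or unet_config["lr_warmup_steps"]
    decay = int(te_config.get("lr_decay_steps", 0)) or unet_config["lr_decay_steps"]

    print(warmup, decay)  # 100 5000 -> te keeps its own warmup, inherits the unet decay steps
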
@@ -274,17 +274,21 @@ class EveryDreamOptimizer():
         curr_lr = args.lr
         d0 = 1e-6 # dadapt
         decouple = True # seems bad to turn off, dadapt_adam only
-        momentum = 0.0 # dadapt_sgd only
+        momentum = 0.0 # dadapt_sgd
+        no_prox = False # ????, dadapt_adan
         growth_rate=float("inf") # dadapt

         if local_optimizer_config is not None:
-            betas = local_optimizer_config["betas"] or betas
-            epsilon = local_optimizer_config["epsilon"] or epsilon
-            weight_decay = local_optimizer_config["weight_decay"] or weight_decay
-            optimizer_name = local_optimizer_config["optimizer"] or "adamw8bit"
+            betas = local_optimizer_config.get("betas", betas)
+            epsilon = local_optimizer_config.get("epsilon", epsilon)
+            weight_decay = local_optimizer_config.get("weight_decay", weight_decay)
+            no_prox = local_optimizer_config.get("no_prox", False)
+            optimizer_name = local_optimizer_config.get("optimizer", "adamw8bit")
             curr_lr = local_optimizer_config.get("lr", curr_lr)
             d0 = local_optimizer_config.get("d0", d0)
             decouple = local_optimizer_config.get("decouple", decouple)
             momentum = local_optimizer_config.get("momentum", momentum)
             growth_rate = local_optimizer_config.get("growth_rate", growth_rate)
             if args.lr is not None:
                 curr_lr = args.lr
                 logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
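
The second hunk switches the optimizer-config reads from bracket indexing with an "or" fallback to dict.get with a default, and introduces the no_prox flag used by dadapt_adan. Two behavioural differences follow: a missing key no longer raises KeyError, and an explicitly falsy value such as weight_decay 0.0 is kept instead of being silently replaced by the default. A small sketch with an illustrative config (the default values here are placeholders, not the project's actual defaults):

    # Illustrative defaults and user config; .get keeps explicit values (even falsy ones)
    # and falls back only when the key is absent.
    defaults = {"optimizer": "adamw8bit", "betas": (0.9, 0.999), "epsilon": 1e-8,
                "weight_decay": 0.01, "no_prox": False}
    local_optimizer_config = {"optimizer": "dadapt_adan", "weight_decay": 0.0}

    resolved = {k: local_optimizer_config.get(k, v) for k, v in defaults.items()}
    print(resolved["optimizer"], resolved["weight_decay"], resolved["no_prox"])
    # dadapt_adan 0.0 False  (the old ["weight_decay"] or default would have turned 0.0 into 0.01)
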
@@ -335,6 +339,19 @@ class EveryDreamOptimizer():
                 growth_rate=1e5,
                 decouple=decouple,
             )
+        elif optimizer_name == "dadapt_adan":
+            opt_class = dadaptation.DAdaptAdan
+            optimizer = opt_class(
+                itertools.chain(parameters),
+                lr=curr_lr,
+                betas=(betas[0], betas[1]),
+                no_prox=no_prox,
+                weight_decay=weight_decay,
+                eps=epsilon,
+                d0=d0,
+                log_every=args.log_step,
+                growth_rate=growth_rate,
+            )
         elif optimizer_name == "dadapt_lion":
             opt_class = dadaptation.DAdaptLion
             optimizer = opt_class(
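
The third hunk adds the dadapt_adan branch, constructing dadaptation.DAdaptAdan from the shared settings above (curr_lr, betas, weight_decay, epsilon, d0, growth_rate) plus the new no_prox flag. As a rough standalone sketch, the toy loop below drives a DAdaptAdan instance on a small model using only keywords that appear in the hunk; lr=1.0 is an assumption following the usual D-Adaptation convention of letting the optimizer's own d estimate set the effective step size, and betas are left at the library defaults.

    import torch
    import dadaptation

    # Toy model and data purely for illustration.
    model = torch.nn.Linear(4, 1)
    optimizer = dadaptation.DAdaptAdan(
        model.parameters(),
        lr=1.0,                    # D-Adaptation convention: keep lr near 1.0
        weight_decay=0.01,
        no_prox=False,             # same flag the hunk above exposes via config
        d0=1e-6,                   # initial d estimate, matching the default above
        growth_rate=float("inf"),
    )

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    for _ in range(3):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
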
@@ -1,7 +1,7 @@
 torch==2.0.1
 torchvision==0.15.2
 transformers==4.29.2
-diffusers[torch]==0.14.0
+diffusers[torch]==0.17.1
 pynvml==11.4.1
 bitsandbytes==0.38.1
 ftfy==6.1.1
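
The dependency hunk bumps the diffusers pin from 0.14.0 to 0.17.1. After reinstalling, the environment can be checked against the new pin with a one-liner (diffusers exposes __version__):

    import diffusers
    assert diffusers.__version__ == "0.17.1", diffusers.__version__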