add dadapt_adan and fix bug in decay steps for te
commit 3bf95d4edc
parent 4fd4e38bbd
optimizer/optimizer.py
@@ -204,8 +204,8 @@ class EveryDreamOptimizer():
         lr_scheduler = get_scheduler(
             te_config.get("lr_scheduler", args.lr_scheduler),
             optimizer=self.optimizer_te,
-            num_warmup_steps=te_config.get("lr_warmup_steps", None),
-            num_training_steps=unet_config.get("lr_decay_steps", None) or unet_config["lr_decay_steps"]
+            num_warmup_steps=int(te_config.get("lr_warmup_steps", None)) or unet_config["lr_warmup_steps"],
+            num_training_steps=int(te_config.get("lr_decay_steps", None)) or unet_config["lr_decay_steps"]
         )
         ret_val.append(lr_scheduler)

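The bug this hunk fixes: the old num_training_steps line read unet_config on both sides of the `or`, so a lr_decay_steps value set in the text-encoder (te) config was silently ignored and the unet value always won; the warmup line likewise gains a fallback to the unet value. A minimal sketch of the before/after fallback semantics, using hypothetical config values:

    # Hypothetical te/unet configs to illustrate the fallback.
    te_config = {"lr_decay_steps": 500}
    unet_config = {"lr_decay_steps": 1500, "lr_warmup_steps": 20}

    # Old (buggy): te_config is never consulted; the unet value always wins.
    old = unet_config.get("lr_decay_steps", None) or unet_config["lr_decay_steps"]
    assert old == 1500

    # New (fixed): the te value is used when present, unet is the fallback.
    new = int(te_config.get("lr_decay_steps", None)) or unet_config["lr_decay_steps"]
    assert new == 500

One caveat worth noting: int(None) raises TypeError, so the new pattern still fails if the te config omits the key entirely; something like int(te_config.get("lr_decay_steps") or 0) would fall back cleanly in that case.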
@@ -272,19 +272,23 @@ class EveryDreamOptimizer():

         default_lr = 1e-6
         curr_lr = args.lr
         d0 = 1e-6 # dadapt
         decouple = True # seems bad to turn off, dadapt_adam only
-        momentum = 0.0 # dadapt_sgd only
+        momentum = 0.0 # dadapt_sgd
+        no_prox = False # dadapt_adan only
+        growth_rate = float("inf") # dadapt

         if local_optimizer_config is not None:
-            betas = local_optimizer_config["betas"] or betas
-            epsilon = local_optimizer_config["epsilon"] or epsilon
-            weight_decay = local_optimizer_config["weight_decay"] or weight_decay
-            optimizer_name = local_optimizer_config["optimizer"] or "adamw8bit"
+            betas = local_optimizer_config.get("betas", betas)
+            epsilon = local_optimizer_config.get("epsilon", epsilon)
+            weight_decay = local_optimizer_config.get("weight_decay", weight_decay)
+            no_prox = local_optimizer_config.get("no_prox", False)
+            optimizer_name = local_optimizer_config.get("optimizer", "adamw8bit")
             curr_lr = local_optimizer_config.get("lr", curr_lr)
             d0 = local_optimizer_config.get("d0", d0)
             decouple = local_optimizer_config.get("decouple", decouple)
             momentum = local_optimizer_config.get("momentum", momentum)
+            growth_rate = local_optimizer_config.get("growth_rate", growth_rate)
             if args.lr is not None:
                 curr_lr = args.lr
                 logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")

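The switch from local_optimizer_config["key"] or default to .get("key", default) changes two behaviors: a missing key no longer raises KeyError, and an explicit falsy value (e.g. weight_decay set to 0) is no longer silently replaced by the default. A small sketch with a hypothetical config:

    # User explicitly disables weight decay.
    cfg = {"weight_decay": 0.0}
    default_wd = 0.01

    # Old pattern: 0.0 is falsy, so the explicit setting is discarded.
    assert (cfg["weight_decay"] or default_wd) == 0.01

    # New pattern: the default applies only when the key is truly absent.
    assert cfg.get("weight_decay", default_wd) == 0.0
    assert {}.get("weight_decay", default_wd) == 0.01

    # The old pattern also raised on a missing key:
    # {}["weight_decay"]  ->  KeyError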
@@ -335,6 +339,19 @@ class EveryDreamOptimizer():
                 growth_rate=1e5,
                 decouple=decouple,
             )
+        elif optimizer_name == "dadapt_adan":
+            opt_class = dadaptation.DAdaptAdan
+            optimizer = opt_class(
+                itertools.chain(parameters),
+                lr=curr_lr,
+                betas=(betas[0], betas[1]),
+                no_prox=no_prox,
+                weight_decay=weight_decay,
+                eps=epsilon,
+                d0=d0,
+                log_every=args.log_step,
+                growth_rate=growth_rate,
+            )
         elif optimizer_name == "dadapt_lion":
             opt_class = dadaptation.DAdaptLion
             optimizer = opt_class(

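With the new branch in place, dadapt_adan is selected through the same optimizer config path as the other D-Adaptation variants. A sketch of a config that would exercise every .get() lookup added above (values are illustrative, not recommendations):

    # Hypothetical optimizer config; keys mirror the .get() lookups above.
    local_optimizer_config = {
        "optimizer": "dadapt_adan",
        "lr": 1.0,                    # D-Adaptation variants expect lr near 1.0; d0 seeds the step-size estimate
        "betas": [0.98, 0.92],
        "epsilon": 1e-8,
        "weight_decay": 0.02,
        "no_prox": False,             # per upstream Adan, toggles how decoupled weight decay is applied
        "d0": 1e-6,
        "growth_rate": float("inf"),  # no cap on how fast the adapted step size may grow
    }

One thing a reviewer may want to double-check: the call passes betas=(betas[0], betas[1]), a 2-tuple, while Adan (and dadaptation's DAdaptAdan default) uses three betas; if the upstream implementation unpacks three values, this branch would raise at step time.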
requirements.txt
@@ -1,7 +1,7 @@
 torch==2.0.1
 torchvision==0.15.2
 transformers==4.29.2
-diffusers[torch]==0.14.0
+diffusers[torch]==0.17.1
 pynvml==11.4.1
 bitsandbytes==0.38.1
 ftfy==6.1.1