more wip optimizer splitting

This commit is contained in:
Victor Hall 2023-04-30 09:28:55 -04:00
parent 72a47741f0
commit 970065c206
4 changed files with 66 additions and 49 deletions

View File

@@ -16,7 +16,7 @@
"unet": {
"optimizer": "adamw8bit",
"lr": 1e-6,
"lr_scheduler": null,
"lr_scheduler": "constant",
"betas": [0.9, 0.999],
"epsilon": 1e-8,
"weight_decay": 0.010

View File

@@ -24,6 +24,7 @@ class EveryDreamOptimizer():
unet: unet model
"""
def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
print(f"\noptimizer_config: \n{optimizer_config}\n")
self.grad_accum = args.grad_accum
self.clip_grad_norm = args.clip_grad_norm
self.text_encoder_params = text_encoder_params
@@ -34,14 +35,6 @@ class EveryDreamOptimizer():
self.lr_scheduler_te, self.lr_scheduler_unet = self.create_lr_schedulers(args, optimizer_config)
self.unet_config = optimizer_config.get("unet", {})
if args.lr is not None:
self.unet_config["lr"] = args.lr
self.te_config = optimizer_config.get("text_encoder", {})
if self.te_config.get("lr", None) is None:
self.te_config["lr"] = self.unet_config["lr"]
te_scale = self.optimizer_config.get("text_encoder_lr_scale", None)
if te_scale is not None:
self.te_config["lr"] = self.unet_config["lr"] * te_scale
optimizer_te_state_path = os.path.join(args.resume_ckpt, OPTIMIZER_TE_STATE_FILENAME)
optimizer_unet_state_path = os.path.join(args.resume_ckpt, OPTIMIZER_UNET_STATE_FILENAME)
@@ -100,37 +93,72 @@ class EveryDreamOptimizer():
self._save_optimizer(self.optimizer_te, os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME))
self._save_optimizer(self.optimizer_unet, os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME))
def create_optimizers(self, args, global_optimizer_config, text_encoder, unet):
def create_optimizers(self, args, global_optimizer_config, text_encoder_params, unet_params):
"""
creates optimizers from config and args for unet and text encoder
returns (optimizer_te, optimizer_unet)
"""
text_encoder_lr_scale = global_optimizer_config.get("text_encoder_lr_scale")
unet_config = global_optimizer_config.get("unet")
te_config = global_optimizer_config.get("text_encoder")
te_config, unet_config = self.fold(te_config=te_config,
unet_config=unet_config,
text_encoder_lr_scale=text_encoder_lr_scale)
if args.disable_textenc_training:
optimizer_te = create_null_optimizer()
else:
optimizer_te = self.create_optimizer(global_optimizer_config.get("text_encoder"), text_encoder)
optimizer_te = self.create_optimizer(args, te_config, text_encoder_params)
if args.disable_unet_training:
optimizer_unet = create_null_optimizer()
else:
optimizer_unet = self.create_optimizer(global_optimizer_config, unet)
optimizer_unet = self.create_optimizer(args, unet_config, unet_params)
return optimizer_te, optimizer_unet
@staticmethod
def fold(te_config, unet_config, text_encoder_lr_scale):
"""
defaults text encoder config values to unet config values if not specified per property
"""
if te_config.get("optimizer", None) is None:
te_config["optimizer"] = unet_config["optimizer"]
if te_config.get("lr", None) is None:
te_config["lr"] = unet_config["lr"]
te_scale = text_encoder_lr_scale
if te_scale is not None:
te_config["lr"] = unet_config["lr"] * te_scale
if te_config.get("weight_decay", None) is None:
te_config["weight_decay"] = unet_config["weight_decay"]
if te_config.get("betas", None) is None:
te_config["betas"] = unet_config["betas"]
if te_config.get("epsilon", None) is None:
te_config["epsilon"] = unet_config["epsilon"]
if te_config.get("lr_scheduler", None) is None:
te_config["lr_scheduler"] = unet_config["lr_scheduler"]
return te_config, unet_config
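
A minimal standalone sketch of the defaulting fold() performs, using illustrative stand-in dicts for the parsed optimizer.json sections (the 0.5 scale factor is hypothetical):

unet_config = {"optimizer": "adamw8bit", "lr": 1e-6, "lr_scheduler": "constant",
               "betas": [0.9, 0.999], "epsilon": 1e-8, "weight_decay": 0.010}
te_config = {}  # nothing specified for the text encoder

# every unset text_encoder property inherits the unet value
for key in ("optimizer", "lr", "weight_decay", "betas", "epsilon", "lr_scheduler"):
    if te_config.get(key) is None:
        te_config[key] = unet_config[key]

text_encoder_lr_scale = 0.5  # hypothetical value from the config
if text_encoder_lr_scale is not None:
    te_config["lr"] = unet_config["lr"] * text_encoder_lr_scale

print(te_config["lr"])  # 5e-07: half the unet lr, all other settings identical to the unet section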
def create_lr_schedulers(self, args, optimizer_config):
lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
lr_scheduler_type_te = optimizer_config.get("lr_scheduler", self.unet_config.lr_scheduler)
lr_scheduler_type_unet = optimizer_config["unet"].get("lr_scheduler", None)
assert lr_scheduler_type_unet is not None, "lr_scheduler must be specified in optimizer config"
lr_scheduler_type_te = optimizer_config.get("lr_scheduler", lr_scheduler_type_unet)
self.lr_scheduler_te = get_scheduler(
lr_scheduler_type_te,
optimizer=self.optimizer_te,
num_warmup_steps=lr_warmup_steps,
num_training_steps=args.lr_decay_steps,
)
self.lr_scheduler_unet = get_scheduler(
args.lr_scheduler,
lr_scheduler_type_unet,
optimizer=self.optimizer_unet,
num_warmup_steps=lr_warmup_steps,
num_training_steps=args.lr_decay_steps,
)
return self.lr_scheduler_te, self.lr_scheduler_unet
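
For reference, a self-contained sketch of the scheduler wiring above, assuming the get_scheduler used here is diffusers.optimization.get_scheduler; the throwaway optimizer and the step counts are placeholders, not the project's real ones:

import torch
from diffusers.optimization import get_scheduler

# throwaway parameter/optimizer just so the scheduler has something to drive
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer_unet = torch.optim.AdamW(params, lr=1e-6)

lr_scheduler_unet = get_scheduler(
    "constant",               # the per-model "lr_scheduler" value read from optimizer.json
    optimizer=optimizer_unet,
    num_warmup_steps=10,      # int(args.lr_decay_steps / 50) when --lr_warmup_steps is unset
    num_training_steps=500,
)
optimizer_unet.step()
lr_scheduler_unet.step()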
def update_grad_scaler(self, global_step):
@@ -169,8 +197,8 @@ class EveryDreamOptimizer():
"""
optimizer.load_state_dict(torch.load(path))
@staticmethod
def create_optimizer(args, local_optimizer_config, parameters):
def create_optimizer(self, args, local_optimizer_config, parameters):
print(f"Creating optimizer from {local_optimizer_config}")
betas = BETAS_DEFAULT
epsilon = EPSILON_DEFAULT
weight_decay = WEIGHT_DECAY_DEFAULT
@@ -182,10 +210,10 @@ class EveryDreamOptimizer():
text_encoder_lr_scale = 1.0
if local_optimizer_config is not None:
betas = local_optimizer_config["betas"]
epsilon = local_optimizer_config["epsilon"]
weight_decay = local_optimizer_config["weight_decay"]
optimizer_name = local_optimizer_config["optimizer"]
betas = local_optimizer_config["betas"] or betas
epsilon = local_optimizer_config["epsilon"] or epsilon
weight_decay = local_optimizer_config["weight_decay"] or weight_decay
optimizer_name = local_optimizer_config["optimizer"] or None
curr_lr = local_optimizer_config.get("lr", curr_lr)
if args.lr is not None:
curr_lr = args.lr
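
A small side sketch (not project code) of how the "or" fallbacks above behave: "or" treats any falsy value as missing, so an explicit 0.0 in the config would silently revert to the default, whereas an "is None" check would preserve it.

WEIGHT_DECAY_DEFAULT = 0.010   # stand-in for the module-level default

configured_weight_decay = 0.0  # a user explicitly turning weight decay off

picked_with_or = configured_weight_decay or WEIGHT_DECAY_DEFAULT
picked_with_none_check = (configured_weight_decay
                          if configured_weight_decay is not None
                          else WEIGHT_DECAY_DEFAULT)

print(picked_with_or)          # 0.01 -- the explicit 0.0 is silently replaced
print(picked_with_none_check)  # 0.0  -- the explicit setting is kept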
@@ -228,18 +256,10 @@ class EveryDreamOptimizer():
)
if args.lr_decay_steps is None or args.lr_decay_steps < 1:
args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.5)
args.lr_decay_steps = int(self.epoch_len * args.max_epochs * 1.5)
lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=lr_warmup_steps,
num_training_steps=args.lr_decay_steps,
)
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
return optimizer

View File

@@ -261,11 +261,6 @@ def setup_args(args):
total_batch_size = args.batch_size * args.grad_accum
if args.scale_lr is not None and args.scale_lr:
tmp_lr = args.lr
args.lr = args.lr * (total_batch_size**0.55)
logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}")
if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
os.makedirs(args.save_ckpt_dir)
@@ -550,7 +545,7 @@ def main(args):
epoch_len = math.ceil(len(train_batch) / args.batch_size)
ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)
exit()
log_args(log_writer, args)
sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
@@ -864,7 +859,7 @@ if __name__ == "__main__":
print("No config file specified, using command line args")
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
#argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
@@ -884,7 +879,7 @@ if __name__ == "__main__":
argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
#argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
@@ -899,7 +894,7 @@ if __name__ == "__main__":
argparser.add_argument("--save_ckpts_from_n_epochs", type=int, default=0, help="Only saves checkpoints starting an N epochs, def: 0 (disabled)")
argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
#argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")

View File

@@ -1,15 +1,17 @@
def check_git():
import subprocess
try:
result = subprocess.run(["git", "symbolic-ref", "--short", "HEAD"], capture_output=True, text=True)
branch = result.stdout.strip()
result = subprocess.run(["git", "symbolic-ref", "--short", "HEAD"], capture_output=True, text=True)
branch = result.stdout.strip()
result = subprocess.run(["git", "rev-list", "--left-right", "--count", f"origin/{branch}...{branch}"], capture_output=True, text=True)
ahead, behind = map(int, result.stdout.split())
result = subprocess.run(["git", "rev-list", "--left-right", "--count", f"origin/{branch}...{branch}"], capture_output=True, text=True)
ahead, behind = map(int, result.stdout.split())
if behind > 0:
print(f"** Your branch '{branch}' is {behind} commit(s) behind the remote. Consider running 'git pull'.")
elif ahead > 0:
print(f"** Your branch '{branch}' is {ahead} commit(s) ahead the remote, consider a pull request.")
else:
print(f"** Your branch '{branch}' is up to date with the remote")
if behind > 0:
print(f"** Your branch '{branch}' is {behind} commit(s) behind the remote. Consider running 'git pull'.")
elif ahead > 0:
print(f"** Your branch '{branch}' is {ahead} commit(s) ahead the remote, consider a pull request.")
else:
print(f"** Your branch '{branch}' is up to date with the remote")
except:
pass