From 970065c2061fbc81aef83adbe570ecee90eb13cd Mon Sep 17 00:00:00 2001
From: Victor Hall
Date: Sun, 30 Apr 2023 09:28:55 -0400
Subject: [PATCH] more wip optimizer splitting

---
 optimizer.json          |  2 +-
 optimizer/optimizers.py | 76 ++++++++++++++++++++++++++---------------
 train.py                | 13 +++----
 utils/check_git.py      | 24 +++++++------
 4 files changed, 66 insertions(+), 49 deletions(-)

diff --git a/optimizer.json b/optimizer.json
index be8d25d..acb5e30 100644
--- a/optimizer.json
+++ b/optimizer.json
@@ -16,7 +16,7 @@
     "unet": {
         "optimizer": "adamw8bit",
         "lr": 1e-6,
-        "lr_scheduler": null,
+        "lr_scheduler": "constant",
         "betas": [0.9, 0.999],
         "epsilon": 1e-8,
         "weight_decay": 0.010
diff --git a/optimizer/optimizers.py b/optimizer/optimizers.py
index a675dbe..5f29970 100644
--- a/optimizer/optimizers.py
+++ b/optimizer/optimizers.py
@@ -24,6 +24,7 @@ class EveryDreamOptimizer():
     unet: unet model
     """
     def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
+        print(f"\noptimizer_config: \n{optimizer_config}\n")
         self.grad_accum = args.grad_accum
         self.clip_grad_norm = args.clip_grad_norm
         self.text_encoder_params = text_encoder_params
@@ -34,14 +35,6 @@ class EveryDreamOptimizer():
         self.lr_scheduler_te, self.lr_scheduler_unet = self.create_lr_schedulers(args, optimizer_config)
 
         self.unet_config = optimizer_config.get("unet", {})
-        if args.lr is not None:
-            self.unet_config["lr"] = args.lr
-        self.te_config = optimizer_config.get("text_encoder", {})
-        if self.te_config.get("lr", None) is None:
-            self.te_config["lr"] = self.unet_config["lr"]
-        te_scale = self.optimizer_config.get("text_encoder_lr_scale", None)
-        if te_scale is not None:
-            self.te_config["lr"] = self.unet_config["lr"] * te_scale
 
         optimizer_te_state_path = os.path.join(args.resume_ckpt, OPTIMIZER_TE_STATE_FILENAME)
         optimizer_unet_state_path = os.path.join(args.resume_ckpt, OPTIMIZER_UNET_STATE_FILENAME)
@@ -100,37 +93,72 @@ class EveryDreamOptimizer():
         self._save_optimizer(self.optimizer_te, os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME))
         self._save_optimizer(self.optimizer_unet, os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME))
 
-    def create_optimizers(self, args, global_optimizer_config, text_encoder, unet):
+    def create_optimizers(self, args, global_optimizer_config, text_encoder_params, unet_params):
         """
         creates optimizers from config and argsfor unet and text encoder
         returns (optimizer_te, optimizer_unet)
         """
+        text_encoder_lr_scale = global_optimizer_config.get("text_encoder_lr_scale")
+        unet_config = global_optimizer_config.get("unet")
+        te_config = global_optimizer_config.get("text_encoder")
+        te_config, unet_config = self.fold(te_config=te_config,
+                                           unet_config=unet_config,
+                                           text_encoder_lr_scale=text_encoder_lr_scale)
+
         if args.disable_textenc_training:
             optimizer_te = create_null_optimizer()
         else:
-            optimizer_te = self.create_optimizer(global_optimizer_config.get("text_encoder"), text_encoder)
+            optimizer_te = self.create_optimizer(args, te_config, text_encoder_params)
         if args.disable_unet_training:
             optimizer_unet = create_null_optimizer()
         else:
-            optimizer_unet = self.create_optimizer(global_optimizer_config, unet)
+            optimizer_unet = self.create_optimizer(args, unet_config, unet_params)
 
         return optimizer_te, optimizer_unet
 
+    @staticmethod
+    def fold(te_config, unet_config, text_encoder_lr_scale):
+        """
+        defaults text encoder config values to unet config values if not specified per property
+        """
+        if te_config.get("optimizer", None) is None:
+            te_config["optimizer"] = unet_config["optimizer"]
+        if te_config.get("lr", None) is None:
+            te_config["lr"] = unet_config["lr"]
+        te_scale = text_encoder_lr_scale
+        if te_scale is not None:
+            te_config["lr"] = unet_config["lr"] * te_scale
+        if te_config.get("weight_decay", None) is None:
+            te_config["weight_decay"] = unet_config["weight_decay"]
+        if te_config.get("betas", None) is None:
+            te_config["betas"] = unet_config["betas"]
+        if te_config.get("epsilon", None) is None:
+            te_config["epsilon"] = unet_config["epsilon"]
+        if te_config.get("lr_scheduler", None) is None:
+            te_config["lr_scheduler"] = unet_config["lr_scheduler"]
+
+        return te_config, unet_config
+
     def create_lr_schedulers(self, args, optimizer_config):
         lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
-        lr_scheduler_type_te = optimizer_config.get("lr_scheduler", self.unet_config.lr_scheduler)
+        lr_scheduler_type_unet = optimizer_config["unet"].get("lr_scheduler", None)
+        assert lr_scheduler_type_unet is not None, "lr_scheduler must be specified in optimizer config"
+        lr_scheduler_type_te = optimizer_config.get("lr_scheduler", lr_scheduler_type_unet)
+
         self.lr_scheduler_te = get_scheduler(
             lr_scheduler_type_te,
             optimizer=self.optimizer_te,
             num_warmup_steps=lr_warmup_steps,
             num_training_steps=args.lr_decay_steps,
         )
+
         self.lr_scheduler_unet = get_scheduler(
-            args.lr_scheduler,
+            lr_scheduler_type_unet,
             optimizer=self.optimizer_unet,
             num_warmup_steps=lr_warmup_steps,
             num_training_steps=args.lr_decay_steps,
         )
+
         return self.lr_scheduler_te, self.lr_scheduler_unet
 
     def update_grad_scaler(self, global_step):
@@ -169,8 +197,8 @@ class EveryDreamOptimizer():
         """
         optimizer.load_state_dict(torch.load(path))
 
-    @staticmethod
-    def create_optimizer(args, local_optimizer_config, parameters):
+    def create_optimizer(self, args, local_optimizer_config, parameters):
+        print(f"Creating optimizer from {local_optimizer_config}")
         betas = BETAS_DEFAULT
         epsilon = EPSILON_DEFAULT
         weight_decay = WEIGHT_DECAY_DEFAULT
@@ -182,10 +210,10 @@ class EveryDreamOptimizer():
         text_encoder_lr_scale = 1.0
 
         if local_optimizer_config is not None:
-            betas = local_optimizer_config["betas"]
-            epsilon = local_optimizer_config["epsilon"]
-            weight_decay = local_optimizer_config["weight_decay"]
-            optimizer_name = local_optimizer_config["optimizer"]
+            betas = local_optimizer_config["betas"] or betas
+            epsilon = local_optimizer_config["epsilon"] or epsilon
+            weight_decay = local_optimizer_config["weight_decay"] or weight_decay
+            optimizer_name = local_optimizer_config["optimizer"] or None
             curr_lr = local_optimizer_config.get("lr", curr_lr)
             if args.lr is not None:
                 curr_lr = args.lr
@@ -228,18 +256,10 @@ class EveryDreamOptimizer():
         )
 
         if args.lr_decay_steps is None or args.lr_decay_steps < 1:
-            args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.5)
+            args.lr_decay_steps = int(self.epoch_len * args.max_epochs * 1.5)
 
         lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
 
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            num_warmup_steps=lr_warmup_steps,
-            num_training_steps=args.lr_decay_steps,
-        )
-
-
         log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
 
         return optimizer
diff --git a/train.py b/train.py
index a3b1c4d..885b2d6 100644
--- a/train.py
+++ b/train.py
@@ -261,11 +261,6 @@ def setup_args(args):
 
     total_batch_size = args.batch_size * args.grad_accum
 
-    if args.scale_lr is not None and args.scale_lr:
-        tmp_lr = args.lr
-        args.lr = args.lr * (total_batch_size**0.55)
-        logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}")
-
     if args.save_ckpt_dir is not None and not os.path.exists(args.save_ckpt_dir):
         os.makedirs(args.save_ckpt_dir)
 
@@ -550,7 +545,7 @@ def main(args):
 
     epoch_len = math.ceil(len(train_batch) / args.batch_size)
     ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)
-
+    exit()
     log_args(log_writer, args)
 
     sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
@@ -864,7 +859,7 @@ if __name__ == "__main__":
        print("No config file specified, using command line args")
 
    argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
-    #argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
+    argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
     argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
     argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
     argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
@@ -884,7 +879,7 @@ if __name__ == "__main__":
     argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
     argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
     argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
-    argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
+    #argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
     argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
     argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
     argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
@@ -899,7 +894,7 @@ if __name__ == "__main__":
     argparser.add_argument("--save_ckpts_from_n_epochs", type=int, default=0, help="Only saves checkpoints starting an N epochs, def: 0 (disabled)")
     argparser.add_argument("--save_full_precision", action="store_true", default=False, help="save ckpts at full FP32")
     argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
-    argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
+    #argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
     argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
     argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
     argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
diff --git a/utils/check_git.py b/utils/check_git.py
index d46f319..86b4c1a 100644
--- a/utils/check_git.py
+++ b/utils/check_git.py
@@ -1,15 +1,17 @@
 def check_git():
     import subprocess
+    try:
+        result = subprocess.run(["git", "symbolic-ref", "--short", "HEAD"], capture_output=True, text=True)
+        branch = result.stdout.strip()
 
-    result = subprocess.run(["git", "symbolic-ref", "--short", "HEAD"], capture_output=True, text=True)
-    branch = result.stdout.strip()
+        result = subprocess.run(["git", "rev-list", "--left-right", "--count", f"origin/{branch}...{branch}"], capture_output=True, text=True)
+        ahead, behind = map(int, result.stdout.split())
 
-    result = subprocess.run(["git", "rev-list", "--left-right", "--count", f"origin/{branch}...{branch}"], capture_output=True, text=True)
-    ahead, behind = map(int, result.stdout.split())
-
-    if behind > 0:
-        print(f"** Your branch '{branch}' is {behind} commit(s) behind the remote. Consider running 'git pull'.")
-    elif ahead > 0:
-        print(f"** Your branch '{branch}' is {ahead} commit(s) ahead the remote, consider a pull request.")
-    else:
-        print(f"** Your branch '{branch}' is up to date with the remote")
\ No newline at end of file
+        if behind > 0:
+            print(f"** Your branch '{branch}' is {behind} commit(s) behind the remote. Consider running 'git pull'.")
+        elif ahead > 0:
+            print(f"** Your branch '{branch}' is {ahead} commit(s) ahead the remote, consider a pull request.")
+        else:
+            print(f"** Your branch '{branch}' is up to date with the remote")
+    except:
+        pass
\ No newline at end of file