diff --git a/optimizer.json b/optimizer.json
index 3572127..be8d25d 100644
--- a/optimizer.json
+++ b/optimizer.json
@@ -1,17 +1,32 @@
 {
     "doc": {
+        "unet": "unet config",
+        "text_encoder": "text encoder config, if properties are null copies from unet config",
+        "text_encoder_lr_scale": "if LR not set on text encoder, sets the LR to a multiple of the unet LR. for example, if unet `lr` is 2e-6 and `text_encoder_lr_scale` is 0.5, the text encoder's LR will be set to `1e-6`.",
+        "-----------------": "-----------------",
         "optimizer": "adamw, adamw8bit, lion",
         "optimizer_desc": "'adamw' in standard 32bit, 'adamw8bit' is bitsandbytes, 'lion' is lucidrains",
-        "lr": "learning rate, if null wil use CLI or main JSON config value",
+        "lr": "learning rate, if null will use CLI or main JSON config value",
+        "lr_scheduler": "overrides global lr scheduler from main config",
         "betas": "exponential decay rates for the moment estimates",
         "epsilon": "value added to denominator for numerical stability, unused for lion",
-        "weight_decay": "weight decay (L2 penalty)",
-        "text_encoder_lr_scale": "scale the text encoder LR relative to the Unet LR. for example, if `lr` is 2e-6 and `text_encoder_lr_scale` is 0.5, the text encoder's LR will be set to `1e-6`."
+        "weight_decay": "weight decay (L2 penalty)"
     },
-    "optimizer": "adamw8bit",
-    "lr": 1e-6,
-    "betas": [0.9, 0.999],
-    "epsilon": 1e-8,
-    "weight_decay": 0.010,
-    "text_encoder_lr_scale": 0.50
+    "text_encoder_lr_scale": 0.5,
+    "unet": {
+        "optimizer": "adamw8bit",
+        "lr": 1e-6,
+        "lr_scheduler": null,
+        "betas": [0.9, 0.999],
+        "epsilon": 1e-8,
+        "weight_decay": 0.010
+    },
+    "text_encoder": {
+        "optimizer": null,
+        "lr": null,
+        "lr_scheduler": null,
+        "betas": null,
+        "epsilon": null,
+        "weight_decay": null
+    }
 }
diff --git a/optimizer/optimizers.py b/optimizer/optimizers.py
new file mode 100644
index 0000000..d17996a
--- /dev/null
+++ b/optimizer/optimizers.py
@@ -0,0 +1,257 @@
+import logging
+import itertools
+import os
+
+import torch
+from torch.cuda.amp import GradScaler
+from diffusers.optimization import get_scheduler
+
+from colorama import Fore, Style
+
+BETAS_DEFAULT = [0.9, 0.999]
+EPSILON_DEFAULT = 1e-8
+WEIGHT_DECAY_DEFAULT = 0.01
+LR_DEFAULT = 1e-6
+OPTIMIZER_TE_STATE_FILENAME = "optimizer_te.pt"
+OPTIMIZER_UNET_STATE_FILENAME = "optimizer_unet.pt"
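+
+# How the per-model config sections are expected to resolve, shown as a
+# minimal sketch against the shipped optimizer.json above (the variable
+# names here are illustrative only, not part of this module):
+#
+#   config = json.load(open("optimizer.json"))
+#   unet_cfg = config["unet"]                    # adamw8bit, lr 1e-6
+#   te_cfg = config["text_encoder"]              # all null -> inherit unet
+#   scale = config.get("text_encoder_lr_scale")  # 0.5
+#   if te_cfg["lr"] is None:
+#       te_cfg["lr"] = unet_cfg["lr"] * (scale if scale is not None else 1.0)
+#   # -> the text encoder lr resolves to 5e-7 with the values shipped above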
+
+class EveryDreamOptimizer():
+    """
+    Wrapper to manage optimizers
+    args: train.py cli args, used for resume path and global overrides
+    optimizer_config: dict loaded from optimizer.json
+    text_encoder_params: text encoder model parameters
+    unet_params: unet model parameters
+    epoch_len: steps per epoch, used for grad accum boundaries and lr decay
+    """
+    def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
+        self.grad_accum = args.grad_accum
+        self.clip_grad_norm = args.clip_grad_norm
+        self.disable_unet_training = args.disable_unet_training
+        self.disable_textenc_training = args.disable_textenc_training
+        self.text_encoder_params = text_encoder_params
+        self.unet_params = unet_params
+        self.epoch_len = epoch_len
+
+        # resolve per-model configs before building optimizers and schedulers
+        self.unet_config = optimizer_config.get("unet", {})
+        if args.lr is not None:
+            self.unet_config["lr"] = args.lr
+        self.te_config = optimizer_config.get("text_encoder", {})
+        if self.te_config.get("lr", None) is None:
+            # a null text encoder lr inherits the unet lr, optionally scaled
+            te_scale = optimizer_config.get("text_encoder_lr_scale", None)
+            if te_scale is not None:
+                logging.info(f" * Using text encoder LR scale {te_scale}")
+                self.te_config["lr"] = self.unet_config["lr"] * te_scale
+            else:
+                self.te_config["lr"] = self.unet_config["lr"]
+        # any other null text encoder properties copy from the unet config
+        for key, value in self.unet_config.items():
+            if self.te_config.get(key, None) is None:
+                self.te_config[key] = value
+
+        self.optimizer_te, self.optimizer_unet = self.create_optimizers(args, self.te_config, self.unet_config, text_encoder_params, unet_params)
+        self.lr_scheduler_te, self.lr_scheduler_unet = self.create_lr_schedulers(args)
+
+        optimizer_te_state_path = os.path.join(args.resume_ckpt, OPTIMIZER_TE_STATE_FILENAME)
+        optimizer_unet_state_path = os.path.join(args.resume_ckpt, OPTIMIZER_UNET_STATE_FILENAME)
+        if os.path.exists(optimizer_te_state_path):
+            logging.info(f"Loading text encoder optimizer state from {optimizer_te_state_path}")
+            self.load_optimizer_state(self.optimizer_te, optimizer_te_state_path)
+        if os.path.exists(optimizer_unet_state_path):
+            logging.info(f"Loading unet optimizer state from {optimizer_unet_state_path}")
+            self.load_optimizer_state(self.optimizer_unet, optimizer_unet_state_path)
+
+        self.scaler = GradScaler(
+            enabled=args.amp,
+            init_scale=2**17.5,
+            growth_factor=2,
+            backoff_factor=1.0/2,
+            growth_interval=25,
+        )
+        logging.info(f" Grad scaler enabled: {self.scaler.is_enabled()} (amp mode)")
+
+    def step(self, loss, step, global_step):
+        self.scaler.scale(loss).backward()
+
+        if self.clip_grad_norm is not None:
+            if not self.disable_unet_training:
+                torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
+            if not self.disable_textenc_training:
+                torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
+
+        # only step the optimizers on grad accum boundaries or the last step of an epoch
+        if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
+            if not self.disable_textenc_training:
+                self.scaler.step(self.optimizer_te)
+            if not self.disable_unet_training:
+                self.scaler.step(self.optimizer_unet)
+            self.scaler.update()
+            self._zero_grad(set_to_none=True)
+
+        self.lr_scheduler_te.step()
+        self.lr_scheduler_unet.step()
+        self.update_grad_scaler(global_step)
+
+    def _zero_grad(self, set_to_none=False):
+        self.optimizer_te.zero_grad(set_to_none=set_to_none)
+        self.optimizer_unet.zero_grad(set_to_none=set_to_none)
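+
+    # Per-batch usage, a sketch of the call pattern train.py follows (loss,
+    # step and global_step come from the training loop; nothing is assumed
+    # here beyond the methods defined in this class):
+    #
+    #   ed_optimizer = EveryDreamOptimizer(args, optimizer_config,
+    #                                      text_encoder.parameters(),
+    #                                      unet.parameters(), epoch_len)
+    #   for step, batch in enumerate(train_dataloader):
+    #       loss = ...  # forward pass and loss, as in train.py
+    #       ed_optimizer.step(loss, step, global_step)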
+
+    def get_scale(self):
+        return self.scaler.get_scale()
+
+    def get_unet_lr(self):
+        return self.optimizer_unet.param_groups[0]['lr']
+
+    def get_textenc_lr(self):
+        return self.optimizer_te.param_groups[0]['lr']
+
+    def save(self, ckpt_path: str):
+        """
+        Saves both optimizer states to the given checkpoint directory
+        """
+        self._save_optimizer(self.optimizer_te, os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME))
+        self._save_optimizer(self.optimizer_unet, os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME))
+
+    def create_optimizers(self, args, te_config, unet_config, text_encoder_params, unet_params):
+        """
+        creates optimizers from the resolved configs and args for the unet and text encoder
+        returns (optimizer_te, optimizer_unet)
+        """
+        if args.disable_textenc_training:
+            optimizer_te = create_null_optimizer()
+        else:
+            optimizer_te = self.create_optimizer(args, te_config, text_encoder_params, name="text_encoder")
+        if args.disable_unet_training:
+            optimizer_unet = create_null_optimizer()
+        else:
+            optimizer_unet = self.create_optimizer(args, unet_config, unet_params, name="unet")
+
+        return optimizer_te, optimizer_unet
+
+    def create_lr_schedulers(self, args):
+        if args.lr_decay_steps is None or args.lr_decay_steps < 1:
+            args.lr_decay_steps = int(self.epoch_len * args.max_epochs * 1.5)
+        lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
+
+        # a per-model "lr_scheduler" overrides the global scheduler from the main config
+        lr_scheduler_type_unet = self.unet_config.get("lr_scheduler", None) or args.lr_scheduler
+        lr_scheduler_type_te = self.te_config.get("lr_scheduler", None) or lr_scheduler_type_unet
+
+        self.lr_scheduler_te = get_scheduler(
+            lr_scheduler_type_te,
+            optimizer=self.optimizer_te,
+            num_warmup_steps=lr_warmup_steps,
+            num_training_steps=args.lr_decay_steps,
+        )
+        self.lr_scheduler_unet = get_scheduler(
+            lr_scheduler_type_unet,
+            optimizer=self.optimizer_unet,
+            num_warmup_steps=lr_warmup_steps,
+            num_training_steps=args.lr_decay_steps,
+        )
+        return self.lr_scheduler_te, self.lr_scheduler_unet
+
+    def update_grad_scaler(self, global_step):
+        if global_step == 500:
+            factor = 1.8
+            self.scaler.set_growth_factor(factor)
+            self.scaler.set_backoff_factor(1/factor)
+            self.scaler.set_growth_interval(50)
+        if global_step == 1000:
+            factor = 1.6
+            self.scaler.set_growth_factor(factor)
+            self.scaler.set_backoff_factor(1/factor)
+            self.scaler.set_growth_interval(50)
+        if global_step == 2000:
+            factor = 1.3
+            self.scaler.set_growth_factor(factor)
+            self.scaler.set_backoff_factor(1/factor)
+            self.scaler.set_growth_interval(100)
+        if global_step == 4000:
+            factor = 1.15
+            self.scaler.set_growth_factor(factor)
+            self.scaler.set_backoff_factor(1/factor)
+            self.scaler.set_growth_interval(100)
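+
+    # Schedule summary: the growth/backoff factors relax and the growth
+    # interval widens as training stabilizes:
+    #   step  500: factor 1.8,  interval 50
+    #   step 1000: factor 1.6,  interval 50
+    #   step 2000: factor 1.3,  interval 100
+    #   step 4000: factor 1.15, interval 100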
+
+    @staticmethod
+    def _save_optimizer(optimizer, path: str):
+        """
+        Saves the optimizer state to a specific path/filename
+        """
+        torch.save(optimizer.state_dict(), path)
+
+    @staticmethod
+    def load_optimizer_state(optimizer: torch.optim.Optimizer, path: str):
+        """
+        Loads saved optimizer state back into an Optimizer object
+        """
+        optimizer.load_state_dict(torch.load(path))
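+
+    # Save/resume round trip, as a sketch (`save_path` stands in for whatever
+    # checkpoint directory train.py passes):
+    #
+    #   ed_optimizer.save(save_path)
+    #   # a later run with args.resume_ckpt == save_path picks up
+    #   # optimizer_te.pt and optimizer_unet.pt automatically in __init__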
"optimizer_cfg": optimizer_config}, name=args.run_name, - #sync_tensorboard=True, # broken? - #dir=log_folder, # only for save, just duplicates the TB log to /{log_folder}/wandb ... ) try: if webbrowser.get(): @@ -545,84 +514,6 @@ def main(args): comment=args.run_name if args.run_name is not None else log_time, ) - betas = [0.9, 0.999] - epsilon = 1e-8 - weight_decay = 0.01 - opt_class = None - optimizer = None - - default_lr = 1e-6 - curr_lr = args.lr - text_encoder_lr_scale = 1.0 - - if optimizer_config is not None: - betas = optimizer_config["betas"] - epsilon = optimizer_config["epsilon"] - weight_decay = optimizer_config["weight_decay"] - optimizer_name = optimizer_config["optimizer"] - curr_lr = optimizer_config.get("lr", curr_lr) - if args.lr is not None: - curr_lr = args.lr - logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}") - - text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale) - if text_encoder_lr_scale != 1.0: - logging.info(f" * Using text encoder LR scale {text_encoder_lr_scale}") - - logging.info(f" * Loaded optimizer args from {optimizer_config_path} *") - - if curr_lr is None: - curr_lr = default_lr - logging.warning(f"No LR setting found, defaulting to {default_lr}") - - curr_text_encoder_lr = curr_lr * text_encoder_lr_scale - - if args.disable_textenc_training: - logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}") - params_to_train = itertools.chain(unet.parameters()) - elif args.disable_unet_training: - logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}") - if text_encoder_lr_scale != 1: - logging.warning(f"{Fore.YELLOW} * Ignoring text_encoder_lr_scale {text_encoder_lr_scale} and using the " - f"Unet LR {curr_lr} for the text encoder instead *{Style.RESET_ALL}") - params_to_train = itertools.chain(text_encoder.parameters()) - else: - logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}") - params_to_train = [{'params': unet.parameters()}, - {'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}] - - if optimizer_name: - if optimizer_name == "lion": - from lion_pytorch import Lion - opt_class = Lion - optimizer = opt_class( - itertools.chain(params_to_train), - lr=curr_lr, - betas=(betas[0], betas[1]), - weight_decay=weight_decay, - ) - elif optimizer_name in ["adamw"]: - opt_class = torch.optim.AdamW - else: - import bitsandbytes as bnb - opt_class = bnb.optim.AdamW8bit - - if not optimizer: - optimizer = opt_class( - itertools.chain(params_to_train), - lr=curr_lr, - betas=(betas[0], betas[1]), - eps=epsilon, - weight_decay=weight_decay, - amsgrad=False, - ) - - if optimizer_state_path is not None: - logging.info(f"Loading optimizer state from {optimizer_state_path}") - load_optimizer(optimizer, optimizer_state_path) - - log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr) - image_train_items = resolve_image_train_items(args) validator = None @@ -658,17 +549,7 @@ def main(args): epoch_len = math.ceil(len(train_batch) / args.batch_size) - if args.lr_decay_steps is None or args.lr_decay_steps < 1: - args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.5) - - lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps - - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=lr_warmup_steps, - num_training_steps=args.lr_decay_steps, - ) + ed_optimizer = 
+
+def create_null_optimizer():
+    """
+    Cheap stand-in so callers can treat a disabled model's optimizer uniformly
+    """
+    return torch.optim.AdamW([torch.zeros(1)], lr=0)
+
+def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr, model_name):
+    """
+    logs the optimizer settings
+    """
+    logging.info(f"{Fore.CYAN} * Optimizer {model_name}: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
+    logging.info(f"{Fore.CYAN}    lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
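+
+# Rendered example of the two log lines above for the shipped unet config
+# (ANSI colors omitted; the shape follows the format strings, the values
+# follow optimizer.json):
+#
+#    * Optimizer unet: AdamW8bit *
+#       lr: 1e-06, betas: [0.9, 0.999], epsilon: 1e-08, weight_decay: 0.01 *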
diff --git a/train.py b/train.py
index bc930be..c351ac2 100644
--- a/train.py
+++ b/train.py
@@ -29,7 +29,7 @@ import traceback
 import shutil
 
 import torch.nn.functional as F
-from torch.cuda.amp import autocast, GradScaler
+from torch.cuda.amp import autocast
 from colorama import Fore, Style
 import numpy as np
 
@@ -60,6 +60,7 @@ from utils.huggingface_downloader import try_download_model_from_hf
 from utils.convert_diff_to_ckpt import convert as converter
 from utils.isolate_rng import isolate_rng
 from utils.check_git import check_git
+from optimizer.optimizers import EveryDreamOptimizer
 
 if torch.cuda.is_available():
     from utils.gpu import GPU
@@ -131,24 +132,17 @@ def setup_local_logger(args):
 
     return datetimestamp
 
-def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, unet_lr, text_encoder_lr):
-    """
-    logs the optimizer settings
-    """
-    logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
-    logging.info(f"{Fore.CYAN}    unet lr: {unet_lr}, text encoder lr: {text_encoder_lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
-
-def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
-    """
-    Saves the optimizer state
-    """
-    torch.save(optimizer.state_dict(), path)
-
-def load_optimizer(optimizer: torch.optim.Optimizer, path: str):
-    """
-    Loads the optimizer state
-    """
-    optimizer.load_state_dict(torch.load(path))
-
 def get_gpu_memory(nvsmi):
     """
@@ -284,28 +278,6 @@ def setup_args(args):
 
     return args
 
-def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
-    if global_step == 500:
-        factor = 1.8
-        scaler.set_growth_factor(factor)
-        scaler.set_backoff_factor(1/factor)
-        scaler.set_growth_interval(50)
-    if global_step == 1000:
-        factor = 1.6
-        scaler.set_growth_factor(factor)
-        scaler.set_backoff_factor(1/factor)
-        scaler.set_growth_interval(50)
-    if global_step == 2000:
-        factor = 1.3
-        scaler.set_growth_factor(factor)
-        scaler.set_backoff_factor(1/factor)
-        scaler.set_growth_interval(100)
-    if global_step == 4000:
-        factor = 1.15
-        scaler.set_growth_factor(factor)
-        scaler.set_backoff_factor(1/factor)
-        scaler.set_growth_interval(100)
-
 def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem], batch_size) -> None:
     undersized_items = [item for item in items if item.is_undersized]
@@ -453,7 +425,6 @@
             logging.info(f" * Saving yaml to {yaml_save_path}")
             shutil.copyfile(yaml_name, yaml_save_path)
-
         if save_optimizer_flag:
             optimizer_path = os.path.join(save_path, "optimizer.pt")
             logging.info(f"  Saving optimizer state to {save_path}")
@@ -520,7 +491,7 @@
         text_encoder = text_encoder.to(device, dtype=torch.float32)
 
     optimizer_config = None
-    optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
+    optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
 
     if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
         with open(os.path.join(os.curdir, optimizer_config_path), "r") as f:
             optimizer_config = json.load(f)
@@ -531,8 +502,6 @@
             project=args.project_name,
             config={"main_cfg": vars(args), "optimizer_cfg": optimizer_config},
             name=args.run_name,
-            #sync_tensorboard=True, # broken?
-            #dir=log_folder, # only for save, just duplicates the TB log to /{log_folder}/wandb ...
         )
         try:
             if webbrowser.get():
@@ -545,84 +514,6 @@
         comment=args.run_name if args.run_name is not None else log_time,
     )
 
-    betas = [0.9, 0.999]
-    epsilon = 1e-8
-    weight_decay = 0.01
-    opt_class = None
-    optimizer = None
-
-    default_lr = 1e-6
-    curr_lr = args.lr
-    text_encoder_lr_scale = 1.0
-
-    if optimizer_config is not None:
-        betas = optimizer_config["betas"]
-        epsilon = optimizer_config["epsilon"]
-        weight_decay = optimizer_config["weight_decay"]
-        optimizer_name = optimizer_config["optimizer"]
-        curr_lr = optimizer_config.get("lr", curr_lr)
-        if args.lr is not None:
-            curr_lr = args.lr
-            logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
-
-        text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale)
-        if text_encoder_lr_scale != 1.0:
-            logging.info(f" * Using text encoder LR scale {text_encoder_lr_scale}")
-
-        logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
-
-    if curr_lr is None:
-        curr_lr = default_lr
-        logging.warning(f"No LR setting found, defaulting to {default_lr}")
-
-    curr_text_encoder_lr = curr_lr * text_encoder_lr_scale
-
-    if args.disable_textenc_training:
-        logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
-        params_to_train = itertools.chain(unet.parameters())
-    elif args.disable_unet_training:
-        logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
-        if text_encoder_lr_scale != 1:
-            logging.warning(f"{Fore.YELLOW} * Ignoring text_encoder_lr_scale {text_encoder_lr_scale} and using the "
-                            f"Unet LR {curr_lr} for the text encoder instead *{Style.RESET_ALL}")
-        params_to_train = itertools.chain(text_encoder.parameters())
-    else:
-        logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
-        params_to_train = [{'params': unet.parameters()},
-                           {'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}]
-
-    if optimizer_name:
-        if optimizer_name == "lion":
-            from lion_pytorch import Lion
-            opt_class = Lion
-            optimizer = opt_class(
-                itertools.chain(params_to_train),
-                lr=curr_lr,
-                betas=(betas[0], betas[1]),
-                weight_decay=weight_decay,
-            )
-        elif optimizer_name in ["adamw"]:
-            opt_class = torch.optim.AdamW
-        else:
-            import bitsandbytes as bnb
-            opt_class = bnb.optim.AdamW8bit
-
-    if not optimizer:
-        optimizer = opt_class(
-            itertools.chain(params_to_train),
-            lr=curr_lr,
-            betas=(betas[0], betas[1]),
-            eps=epsilon,
-            weight_decay=weight_decay,
-            amsgrad=False,
-        )
-
-    if optimizer_state_path is not None:
-        logging.info(f"Loading optimizer state from {optimizer_state_path}")
-        load_optimizer(optimizer, optimizer_state_path)
-
-    log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
-
     image_train_items = resolve_image_train_items(args)
 
     validator = None
@@ -658,17 +549,7 @@
 
     epoch_len = math.ceil(len(train_batch) / args.batch_size)
 
-    if args.lr_decay_steps is None or args.lr_decay_steps < 1:
-        args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.5)
-
-    lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
-
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=lr_warmup_steps,
-        num_training_steps=args.lr_decay_steps,
-    )
+    ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)
 
     log_args(log_writer, args)
@@ -742,15 +623,6 @@
     logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
     logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
 
-    scaler = GradScaler(
-        enabled=args.amp,
-        init_scale=2**17.5,
-        growth_factor=2,
-        backoff_factor=1.0/2,
-        growth_interval=25,
-    )
-    logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
-
     epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True)
     epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
     epoch_times = []
@@ -868,20 +740,18 @@
                     loss_scale = batch["runt_size"] / args.batch_size
                     loss = loss * loss_scale
 
-                scaler.scale(loss).backward()
-
-                if args.clip_grad_norm is not None:
-                    if not args.disable_unet_training:
-                        torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
-                    if not args.disable_textenc_training:
-                        torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
-
-                if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
-                    scaler.step(optimizer)
-                    scaler.update()
-                    optimizer.zero_grad(set_to_none=True)
-
-                lr_scheduler.step()
+                ed_optimizer.step(loss, step, global_step)
 
                 loss_step = loss.detach().item()
@@ -895,23 +765,23 @@
                 loss_epoch.append(loss_step)
 
                 if (global_step + 1) % args.log_step == 0:
-                    curr_lr = lr_scheduler.get_last_lr()[0]
                     loss_local = sum(loss_log_step) / len(loss_log_step)
+                    lr_unet = ed_optimizer.get_unet_lr()
+                    lr_textenc = ed_optimizer.get_textenc_lr()
                     loss_log_step = []
-                    logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
-                    if args.disable_textenc_training or args.disable_unet_training or text_encoder_lr_scale == 1:
-                        log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
-                    else:
-                        log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step)
-                        curr_text_encoder_lr = lr_scheduler.get_last_lr()[1]
-                        log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=curr_text_encoder_lr, global_step=global_step)
+
+                    log_writer.add_scalar(tag="hyperparameter/lr unet", scalar_value=lr_unet, global_step=global_step)
+                    log_writer.add_scalar(tag="hyperparameter/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
                     log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
+
                     sum_img = sum(images_per_sec_log_step)
                     avg = sum_img / len(images_per_sec_log_step)
                     images_per_sec_log_step = []
                     if args.amp:
-                        log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step)
+                        log_writer.add_scalar(tag="hyperparameter/grad scale", scalar_value=ed_optimizer.get_scale(), global_step=global_step)
                     log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
+
+                    logs = {"loss/log_step": loss_local, "lr_unet": lr_unet, "lr_te": lr_textenc, "img/s": images_per_sec}
                     append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
                     torch.cuda.empty_cache()
@@ -933,7 +803,7 @@
                 del batch
                 global_step += 1
-                update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
+                # grad scaler updates now happen inside ed_optimizer.step()
                 # end of step
 
             steps_pbar.close()