From a6610625eb684205df4fd23a2d3643899efb171c Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Sun, 14 May 2023 11:49:11 +0200
Subject: [PATCH] Squashed commit of the following:

commit 0f890f2d6bbccee225f738934f4c4450323f19a2
Merge: c008c40 003b089
Author: Damian Stewart
Date:   Sun May 14 11:47:40 2023 +0200

    Merge remote-tracking branch 'upstream/main' into feat_te_last_n_layers_unsquashed

commit c008c404f19ebc6b78085f42a4e39aeb2ba00d04
Author: Damian Stewart
Date:   Sun May 14 11:23:20 2023 +0200

    finalize TE layer freezing

commit 7377b10d59e32a6fea5d321a598ae4504e1a9f36
Author: Damian Stewart
Date:   Thu May 11 20:45:28 2023 +0200

    remove zero_lr method

commit 4af13ba816c2811d7b5bd6fbb81a32bca6747e99
Author: Damian Stewart
Date:   Thu May 11 20:05:01 2023 +0200

    Revert "rename parameters"

    This reverts commit aa33c61337599ab2d90b34aaf8c3d36fd4edf147.

commit aa33c61337599ab2d90b34aaf8c3d36fd4edf147
Author: Damian Stewart
Date:   Tue May 9 00:28:00 2023 +0200

    rename parameters

commit 1da867e6fadb873da2571371a73b522406d76a18
Author: Damian Stewart
Date:   Sun May 7 22:28:29 2023 +0200

    remove silly check

commit 483cb2a635c3fe5a044edf4ea8de095bedc3f0ac
Author: Damian Stewart
Date:   Sun May 7 20:53:43 2023 +0200

    use 1e-10 not 0 as 'zero' lr

commit e5d230e6c765a7e25dc6381d09bd0a66a9a54ec2
Author: Damian Stewart
Date:   Sun May 7 20:51:51 2023 +0200

    add experimental 'zero_lr' freeze method

commit bcf24ee59a443c0ee71d622e65e1043b547f845e
Author: Damian Stewart
Date:   Sun May 7 17:32:11 2023 +0200

    fix layer selection bug

commit 7ee33eff8740e095f85042dcbb792e025b179c6c
Author: Damian Stewart
Date:   Sun May 7 17:25:25 2023 +0200

    put back the 'drop' method and make accessible

commit 76dfbf6dd6f43f3aa9a7f4629baa8e86573d9520
Author: Damian Stewart
Date:   Sun May 7 16:39:05 2023 +0200

    wip getting final_layer_norm to work

commit a19d43651a87525251106ed57238cd2cd1c3f3ff
Author: Damian Stewart
Date:   Sun May 7 16:15:53 2023 +0200

    work around a crash when freeze_final_layer_norm is True

commit c2a44eb25132941b92e2ecd0be3682ae3c6838c2
Author: Damian Stewart
Date:   Sun May 7 15:47:10 2023 +0200

    improve logging, add extra freezing controls

commit a31e64c4c0d12dfb6583dd6f22c8c09ba7840410
Author: Damian Stewart
Date:   Sun May 7 13:46:38 2023 +0200

    alternative method to freeze early TE layers

commit 095692fd4ea53707c012217898321860d8b9329f
Merge: 876072c 4c5ce81
Author: Damian Stewart
Date:   Sun May 7 11:52:51 2023 +0200

    Merge branch 'victorchall:main' into feat_te_last_n_layers

commit 876072c46394fde721a6026f7a6ef72ccb150ddb
Author: Damian Stewart
Date:   Sun May 7 01:41:50 2023 +0200

    implement last N layers training only for TE
---
 optimizer.json          | 11 +++++++-
 optimizer/optimizers.py | 62 ++++++++++++++++++++++++++++++++++-------
 train.py                |  6 +++-
 3 files changed, 67 insertions(+), 12 deletions(-)

diff --git a/optimizer.json b/optimizer.json
index 4ac8d6e..f4b9e8e 100644
--- a/optimizer.json
+++ b/optimizer.json
@@ -12,7 +12,11 @@
         "lr_decay_steps": "number of steps to decay LR to zero for cosine, if null will use CLI or default a value based on max epochs",
         "betas": "exponential decay rates for the moment estimates",
         "epsilon": "value added to denominator for numerical stability, unused for lion",
-        "weight_decay": "weight decay (L2 penalty)"
+        "weight_decay": "weight decay (L2 penalty)",
+        "------------------": "-----------------",
+        "freeze_embeddings": "freeze the text embeddings",
+        "freeze_front_n_layers": "freeze the front N layers of the text encoder (you can pass eg -2 to leave only the last 2 layers unfrozen)",
+        "freeze_final_layer_norm": "freeze the final layer norm"
     },
     "base": {
         "optimizer": "adamw8bit",
@@ -33,5 +37,10 @@
         "betas": null,
         "epsilon": null,
         "weight_decay": null
+    },
+    "text_encoder_freezing": {
+        "freeze_embeddings": false,
+        "freeze_front_n_layers": null,
+        "freeze_final_layer_norm": true
     }
 }
diff --git a/optimizer/optimizers.py b/optimizer/optimizers.py
index feda9a1..eb45737 100644
--- a/optimizer/optimizers.py
+++ b/optimizer/optimizers.py
@@ -17,6 +17,8 @@ limitations under the License.
 import logging
 import itertools
 import os
+from itertools import chain
+from typing import Generator, Any
 
 import torch
 from torch.cuda.amp import autocast, GradScaler
@@ -40,12 +42,13 @@ class EveryDreamOptimizer():
     text_encoder: text encoder model parameters
     unet: unet model parameters
     """
-    def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
+    def __init__(self, args, optimizer_config, text_encoder, unet, epoch_len):
         del optimizer_config["doc"]
         print(f"\n raw optimizer_config:")
         pprint.pprint(optimizer_config)
         self.epoch_len = epoch_len
         self.te_config, self.base_config = self.get_final_optimizer_configs(args, optimizer_config)
+        self.te_freeze_config = optimizer_config.get("text_encoder_freezing", {})
         print(f"final unet optimizer config:")
         pprint.pprint(self.base_config)
         print(f"final text encoder optimizer config:")
@@ -53,11 +56,14 @@ class EveryDreamOptimizer():
 
         self.grad_accum = args.grad_accum
         self.clip_grad_norm = args.clip_grad_norm
-        self.text_encoder_params = text_encoder_params
-        self.unet_params = unet_params
+
+        self.text_encoder_params = self._apply_text_encoder_freeze(text_encoder)
+        self.unet_params = unet.parameters()
 
         self.optimizers = []
-        self.optimizer_te, self.optimizer_unet = self.create_optimizers(args, text_encoder_params, unet_params)
+        self.optimizer_te, self.optimizer_unet = self.create_optimizers(args,
+                                                                        self.text_encoder_params,
+                                                                        self.unet_params)
         self.optimizers.append(self.optimizer_te) if self.optimizer_te is not None else None
         self.optimizers.append(self.optimizer_unet) if self.optimizer_unet is not None else None
 
@@ -136,11 +142,11 @@
         if args.disable_textenc_training:
             optimizer_te = None
         else:
-            optimizer_te = self._create_optimizer(args, self.te_config, text_encoder_params)
+            optimizer_te = self._create_optimizer("text encoder", args, self.te_config, text_encoder_params)
         if args.disable_unet_training:
             optimizer_unet = None
         else:
-            optimizer_unet = self._create_optimizer(args, self.base_config, unet_params)
+            optimizer_unet = self._create_optimizer("unet", args, self.base_config, unet_params)
 
         return optimizer_te, optimizer_unet
 
@@ -248,7 +254,7 @@
             logging.warning(f"{Fore.LIGHTYELLOW_EX}**Failed to load optimizer state from {path}, optimizer state will not be loaded, \n * Exception: {e}{Style.RESET_ALL}")
             pass
 
-    def _create_optimizer(self, args, local_optimizer_config, parameters):
+    def _create_optimizer(self, label, args, local_optimizer_config, parameters):
         betas = BETAS_DEFAULT
         epsilon = EPSILON_DEFAULT
         weight_decay = WEIGHT_DECAY_DEFAULT
@@ -298,12 +304,48 @@
                 amsgrad=False,
             )
 
-        log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr)
+        log_optimizer(label, optimizer, betas, epsilon, weight_decay, curr_lr)
         return optimizer
 
-def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
+    def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
+        parameters = itertools.chain([])
+
+        if self.te_freeze_config.get('freeze_embeddings', False):
+            # freeze embeddings
+            print(" ❄️ freezing embeddings")
+        else:
+            parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.parameters())
+
+        freeze_front_n_layers = self.te_freeze_config.get('freeze_front_n_layers', None)
+        if freeze_front_n_layers is None:
+            parameters = itertools.chain(parameters, text_encoder.text_model.encoder.layers.parameters())
+        else:
+            # freeze the specified CLIP text encoder layers
+            layers = text_encoder.text_model.encoder.layers
+            print(f" ❄️ freezing text encoder layers 0-{len(layers[:freeze_front_n_layers])} of {len(layers)}")
+            parameters = itertools.chain(parameters, layers[freeze_front_n_layers:].parameters())
+
+        if self.te_freeze_config.get('freeze_final_layer_norm', False):
+            # instead of freezing the final layer norm parameters, we simply do not return them
+            print(" ❄️ freezing final layer norm")
+        else:
+            parameters = itertools.chain(parameters, text_encoder.text_model.final_layer_norm.parameters())
+
+        return parameters
+
+
+def log_optimizer(label: str, optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
     """
     logs the optimizer settings
     """
-    logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
+    all_params = sum([g['params'] for g in optimizer.param_groups], [])
+    frozen_parameter_count = len([p for p in all_params if not p.requires_grad])
+    total_parameter_count = len(all_params)
+    if frozen_parameter_count > 0:
+        param_info = f"({total_parameter_count} parameters, {frozen_parameter_count} frozen)"
+    else:
+        param_info = f"({total_parameter_count} parameters)"
+
+    logging.info(f"{Fore.CYAN} * {label} optimizer: {optimizer.__class__.__name__} {param_info} *{Style.RESET_ALL}")
     logging.info(f"{Fore.CYAN}    lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
+
diff --git a/train.py b/train.py
index de58cf1..41f5b7c 100644
--- a/train.py
+++ b/train.py
@@ -543,7 +543,11 @@
 
     epoch_len = math.ceil(len(train_batch) / args.batch_size)
 
-    ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)
+    ed_optimizer = EveryDreamOptimizer(args,
+                                       optimizer_config,
+                                       text_encoder,
+                                       unet,
+                                       epoch_len)
 
     log_args(log_writer, args)
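
Note (illustration, not part of the patch): `_apply_text_encoder_freeze` never sets `requires_grad = False`; it freezes by omitting parameter groups from the iterator handed to the optimizer. A minimal standalone sketch of the same selection logic against a Hugging Face `CLIPTextModel` is below. The checkpoint name and the example config values are assumptions for demonstration only, not taken from this diff.

```python
# Standalone sketch of the "text_encoder_freezing" selection idea.
# Assumptions (not from the patch): the CLIP checkpoint name and the
# example config values below are illustrative only.
import itertools

from transformers import CLIPTextModel

freeze_config = {
    "freeze_embeddings": True,
    "freeze_front_n_layers": -2,   # negative value: leave only the last 2 layers unfrozen
    "freeze_final_layer_norm": True,
}

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

params = itertools.chain([])

if not freeze_config["freeze_embeddings"]:
    params = itertools.chain(params, text_encoder.text_model.embeddings.parameters())

layers = text_encoder.text_model.encoder.layers  # nn.ModuleList of CLIP encoder layers
front_n = freeze_config["freeze_front_n_layers"]
if front_n is None:
    params = itertools.chain(params, layers.parameters())
else:
    # layers[front_n:] keeps only the tail of the stack, e.g. the last 2 layers for -2
    params = itertools.chain(params, layers[front_n:].parameters())

if not freeze_config["freeze_final_layer_norm"]:
    params = itertools.chain(params, text_encoder.text_model.final_layer_norm.parameters())

trainable = list(params)
total = len(list(text_encoder.parameters()))
print(f"{len(trainable)} of {total} parameter tensors would be handed to the text encoder optimizer")
```

With these example settings only the last two transformer blocks remain trainable. Because the omitted parameters keep `requires_grad=True`, they still receive gradients during backward; they are simply never stepped, which matches the comment in the patch about not returning the final layer norm parameters rather than freezing them in place.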