diff --git a/optimizer.json b/optimizer.json
index 4ac8d6e..f4b9e8e 100644
--- a/optimizer.json
+++ b/optimizer.json
@@ -12,7 +12,11 @@
         "lr_decay_steps": "number of steps to decay LR to zero for cosine, if null will use CLI or default a value based on max epochs",
         "betas": "exponential decay rates for the moment estimates",
         "epsilon": "value added to denominator for numerical stability, unused for lion",
-        "weight_decay": "weight decay (L2 penalty)"
+        "weight_decay": "weight decay (L2 penalty)",
+        "------------------": "-----------------",
+        "freeze_embeddings": "freeze the text embeddings",
+        "freeze_front_n_layers": "freeze the front N layers of the text encoder (you can pass eg -2 to leave only the last 2 layers unfrozen)",
+        "freeze_final_layer_norm": "freeze the final layer norm"
     },
     "base": {
         "optimizer": "adamw8bit",
@@ -33,5 +37,10 @@
         "betas": null,
         "epsilon": null,
         "weight_decay": null
+    },
+    "text_encoder_freezing": {
+        "freeze_embeddings": false,
+        "freeze_front_n_layers": null,
+        "freeze_final_layer_norm": true
     }
 }
diff --git a/optimizer/optimizers.py b/optimizer/optimizers.py
index feda9a1..eb45737 100644
--- a/optimizer/optimizers.py
+++ b/optimizer/optimizers.py
@@ -17,6 +17,8 @@ limitations under the License.
 import logging
 import itertools
 import os
+from itertools import chain
+from typing import Generator, Any
 
 import torch
 from torch.cuda.amp import autocast, GradScaler
@@ -40,12 +42,13 @@ class EveryDreamOptimizer():
     text_encoder: text encoder model parameters
     unet: unet model parameters
     """
-    def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
+    def __init__(self, args, optimizer_config, text_encoder, unet, epoch_len):
         del optimizer_config["doc"]
         print(f"\n raw optimizer_config:")
         pprint.pprint(optimizer_config)
         self.epoch_len = epoch_len
         self.te_config, self.base_config = self.get_final_optimizer_configs(args, optimizer_config)
+        self.te_freeze_config = optimizer_config.get("text_encoder_freezing", {})
         print(f"final unet optimizer config:")
         pprint.pprint(self.base_config)
         print(f"final text encoder optimizer config:")
@@ -53,11 +56,14 @@ class EveryDreamOptimizer():
 
         self.grad_accum = args.grad_accum
         self.clip_grad_norm = args.clip_grad_norm
-        self.text_encoder_params = text_encoder_params
-        self.unet_params = unet_params
+
+        self.text_encoder_params = self._apply_text_encoder_freeze(text_encoder)
+        self.unet_params = unet.parameters()
 
         self.optimizers = []
-        self.optimizer_te, self.optimizer_unet = self.create_optimizers(args, text_encoder_params, unet_params)
+        self.optimizer_te, self.optimizer_unet = self.create_optimizers(args,
+                                                                        self.text_encoder_params,
+                                                                        self.unet_params)
         self.optimizers.append(self.optimizer_te) if self.optimizer_te is not None else None
         self.optimizers.append(self.optimizer_unet) if self.optimizer_unet is not None else None
 
@@ -136,11 +142,11 @@ class EveryDreamOptimizer():
         if args.disable_textenc_training:
             optimizer_te = None
         else:
-            optimizer_te = self._create_optimizer(args, self.te_config, text_encoder_params)
+            optimizer_te = self._create_optimizer("text encoder", args, self.te_config, text_encoder_params)
         if args.disable_unet_training:
             optimizer_unet = None
         else:
-            optimizer_unet = self._create_optimizer(args, self.base_config, unet_params)
+            optimizer_unet = self._create_optimizer("unet", args, self.base_config, unet_params)
 
         return optimizer_te, optimizer_unet
 
@@ -248,7 +254,7 @@ class EveryDreamOptimizer():
             logging.warning(f"{Fore.LIGHTYELLOW_EX}**Failed to load optimizer state from {path}, optimizer state will not be loaded, \n * Exception: {e}{Style.RESET_ALL}")
             pass
 
-    def _create_optimizer(self, args, local_optimizer_config, parameters):
+    def _create_optimizer(self, label, args, local_optimizer_config, parameters):
         betas = BETAS_DEFAULT
         epsilon = EPSILON_DEFAULT
         weight_decay = WEIGHT_DECAY_DEFAULT
@@ -298,12 +304,48 @@ class EveryDreamOptimizer():
                 amsgrad=False,
             )
 
-        log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr)
+        log_optimizer(label, optimizer, betas, epsilon, weight_decay, curr_lr)
         return optimizer
 
-def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
+    def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
+        parameters = itertools.chain([])
+
+        if self.te_freeze_config.get('freeze_embeddings', False):
+            # freeze embeddings
+            print(" ❄️ freezing embeddings")
+        else:
+            parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.parameters())
+
+        freeze_front_n_layers = self.te_freeze_config.get('freeze_front_n_layers', None)
+        if freeze_front_n_layers is None:
+            parameters = itertools.chain(parameters, text_encoder.text_model.encoder.layers.parameters())
+        else:
+            # freeze the specified CLIP text encoder layers
+            layers = text_encoder.text_model.encoder.layers
+            print(f" ❄️ freezing text encoder layers 0-{len(layers[:freeze_front_n_layers])} of {len(layers)}")
+            parameters = itertools.chain(parameters, layers[freeze_front_n_layers:].parameters())
+
+        if self.te_freeze_config.get('freeze_final_layer_norm', False):
+            # instead of freezing the final layer norm parameters, we simply do not return them
+            print(" ❄️ freezing final layer norm")
+        else:
+            parameters = itertools.chain(parameters, text_encoder.text_model.final_layer_norm.parameters())
+
+        return parameters
+
+
+def log_optimizer(label: str, optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
     """
     logs the optimizer settings
     """
-    logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
+    all_params = sum([g['params'] for g in optimizer.param_groups], [])
+    frozen_parameter_count = len([p for p in all_params if not p.requires_grad])
+    total_parameter_count = len(all_params)
+    if frozen_parameter_count > 0:
+        param_info = f"({total_parameter_count} parameters, {frozen_parameter_count} frozen)"
+    else:
+        param_info = f"({total_parameter_count} parameters)"
+
+    logging.info(f"{Fore.CYAN} * {label} optimizer: {optimizer.__class__.__name__} {param_info} *{Style.RESET_ALL}")
     logging.info(f"{Fore.CYAN}    lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
+
diff --git a/train.py b/train.py
index de58cf1..41f5b7c 100644
--- a/train.py
+++ b/train.py
@@ -543,7 +543,11 @@ def main(args):
 
     epoch_len = math.ceil(len(train_batch) / args.batch_size)
 
-    ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)
+    ed_optimizer = EveryDreamOptimizer(args,
+                                       optimizer_config,
+                                       text_encoder,
+                                       unet,
+                                       epoch_len)
 
     log_args(log_writer, args)
 
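
Usage note (not part of the diff): text encoder freezing is driven entirely by the optional "text_encoder_freezing" block in optimizer.json, with the defaults shown above (nothing frozen except the final layer norm). As an illustrative example only, a configuration like the following would exclude the embeddings, all but the last two CLIP text encoder layers (per the "eg -2" hint in the doc block), and the final layer norm from the text encoder optimizer:

    "text_encoder_freezing": {
        "freeze_embeddings": true,
        "freeze_front_n_layers": -2,
        "freeze_final_layer_norm": true
    }

Note the "frozen" parts are not given requires_grad=False; per _apply_text_encoder_freeze, they are simply not passed to the optimizer.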