Squashed commit of the following:
commit 0f890f2d6bbccee225f738934f4c4450323f19a2 (Merge: c008c40 003b089)
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 14 11:47:40 2023 +0200
    Merge remote-tracking branch 'upstream/main' into feat_te_last_n_layers_unsquashed

commit c008c404f19ebc6b78085f42a4e39aeb2ba00d04
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 14 11:23:20 2023 +0200
    finalize TE layer freezing

commit 7377b10d59e32a6fea5d321a598ae4504e1a9f36
Author: Damian Stewart <d@damianstewart.com>
Date: Thu May 11 20:45:28 2023 +0200
    remove zero_lr method

commit 4af13ba816c2811d7b5bd6fbb81a32bca6747e99
Author: Damian Stewart <d@damianstewart.com>
Date: Thu May 11 20:05:01 2023 +0200
    Revert "rename parameters"
    This reverts commit aa33c61337599ab2d90b34aaf8c3d36fd4edf147.

commit aa33c61337599ab2d90b34aaf8c3d36fd4edf147
Author: Damian Stewart <d@damianstewart.com>
Date: Tue May 9 00:28:00 2023 +0200
    rename parameters

commit 1da867e6fadb873da2571371a73b522406d76a18
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 7 22:28:29 2023 +0200
    remove silly check

commit 483cb2a635c3fe5a044edf4ea8de095bedc3f0ac
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 7 20:53:43 2023 +0200
    use 1e-10 not 0 as 'zero' lr

commit e5d230e6c765a7e25dc6381d09bd0a66a9a54ec2
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 7 20:51:51 2023 +0200
    add experimental 'zero_lr' freeze method

commit bcf24ee59a443c0ee71d622e65e1043b547f845e
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 7 17:32:11 2023 +0200
    fix layer selection bug

commit 7ee33eff8740e095f85042dcbb792e025b179c6c
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 7 17:25:25 2023 +0200
    put back the 'drop' method and make accessible

commit 76dfbf6dd6f43f3aa9a7f4629baa8e86573d9520
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 7 16:39:05 2023 +0200
    wip getting final_layer_norm to work

commit a19d43651a87525251106ed57238cd2cd1c3f3ff
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 7 16:15:53 2023 +0200
    work around a crash when freeze_final_layer_norm is True

commit c2a44eb25132941b92e2ecd0be3682ae3c6838c2
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 7 15:47:10 2023 +0200
    improve logging, add extra freezing controls

commit a31e64c4c0d12dfb6583dd6f22c8c09ba7840410
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 7 13:46:38 2023 +0200
    alternative method to freeze early TE layers

commit 095692fd4ea53707c012217898321860d8b9329f (Merge: 876072c 4c5ce81)
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 7 11:52:51 2023 +0200
    Merge branch 'victorchall:main' into feat_te_last_n_layers

commit 876072c46394fde721a6026f7a6ef72ccb150ddb
Author: Damian Stewart <d@damianstewart.com>
Date: Sun May 7 01:41:50 2023 +0200
    implement last N layers training only for TE
This commit is contained in: parent 4a2e0bebdd, commit a6610625eb
@@ -12,7 +12,11 @@
         "lr_decay_steps": "number of steps to decay LR to zero for cosine, if null will use CLI or default a value based on max epochs",
         "betas": "exponential decay rates for the moment estimates",
         "epsilon": "value added to denominator for numerical stability, unused for lion",
-        "weight_decay": "weight decay (L2 penalty)"
+        "weight_decay": "weight decay (L2 penalty)",
+        "------------------": "-----------------",
+        "freeze_embeddings": "freeze the text embeddings",
+        "freeze_front_n_layers": "freeze the front N layers of the text encoder (you can pass eg -2 to leave only the last 2 layers unfrozen)",
+        "freeze_final_layer_norm": "freeze the final layer norm"
     },
     "base": {
         "optimizer": "adamw8bit",
@@ -33,5 +37,10 @@
         "betas": null,
         "epsilon": null,
         "weight_decay": null
-    }
+    },
+    "text_encoder_freezing": {
+        "freeze_embeddings": false,
+        "freeze_front_n_layers": null,
+        "freeze_final_layer_norm": true
+    }
 }
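
As an aside on the options documented above: the block below is a hedged, illustrative example (values chosen for illustration, not defaults from this commit) of a text_encoder_freezing setting that trains only the last two CLIP text encoder layers, written as the Python dict that optimizer_config.get("text_encoder_freezing", {}) returns once the JSON is loaded.

    # Illustrative values only: freeze everything in the text encoder except
    # its last 2 transformer layers and the final layer norm.
    text_encoder_freezing = {
        "freeze_embeddings": True,         # do not train token/position embeddings
        "freeze_front_n_layers": -2,       # negative N: freeze all but the last 2 layers
        "freeze_final_layer_norm": False,  # keep final_layer_norm trainable
    }
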
@@ -17,6 +17,8 @@ limitations under the License.
 import logging
+import itertools
 import os
+from itertools import chain
 from typing import Generator, Any
 
 import torch
 from torch.cuda.amp import autocast, GradScaler
@@ -40,12 +42,13 @@ class EveryDreamOptimizer():
     text_encoder: text encoder model parameters
     unet: unet model parameters
     """
-    def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
+    def __init__(self, args, optimizer_config, text_encoder, unet, epoch_len):
         del optimizer_config["doc"]
         print(f"\n raw optimizer_config:")
         pprint.pprint(optimizer_config)
         self.epoch_len = epoch_len
         self.te_config, self.base_config = self.get_final_optimizer_configs(args, optimizer_config)
+        self.te_freeze_config = optimizer_config.get("text_encoder_freezing", {})
         print(f"final unet optimizer config:")
         pprint.pprint(self.base_config)
         print(f"final text encoder optimizer config:")
@@ -53,11 +56,14 @@ class EveryDreamOptimizer():
 
         self.grad_accum = args.grad_accum
         self.clip_grad_norm = args.clip_grad_norm
-        self.text_encoder_params = text_encoder_params
-        self.unet_params = unet_params
+
+        self.text_encoder_params = self._apply_text_encoder_freeze(text_encoder)
+        self.unet_params = unet.parameters()
 
         self.optimizers = []
-        self.optimizer_te, self.optimizer_unet = self.create_optimizers(args, text_encoder_params, unet_params)
+        self.optimizer_te, self.optimizer_unet = self.create_optimizers(args,
+                                                                        self.text_encoder_params,
+                                                                        self.unet_params)
         self.optimizers.append(self.optimizer_te) if self.optimizer_te is not None else None
         self.optimizers.append(self.optimizer_unet) if self.optimizer_unet is not None else None
 
@@ -136,11 +142,11 @@ class EveryDreamOptimizer():
         if args.disable_textenc_training:
             optimizer_te = None
         else:
-            optimizer_te = self._create_optimizer(args, self.te_config, text_encoder_params)
+            optimizer_te = self._create_optimizer("text encoder", args, self.te_config, text_encoder_params)
         if args.disable_unet_training:
             optimizer_unet = None
         else:
-            optimizer_unet = self._create_optimizer(args, self.base_config, unet_params)
+            optimizer_unet = self._create_optimizer("unet", args, self.base_config, unet_params)
 
         return optimizer_te, optimizer_unet
 
@@ -248,7 +254,7 @@ class EveryDreamOptimizer():
             logging.warning(f"{Fore.LIGHTYELLOW_EX}**Failed to load optimizer state from {path}, optimizer state will not be loaded, \n * Exception: {e}{Style.RESET_ALL}")
             pass
 
-    def _create_optimizer(self, args, local_optimizer_config, parameters):
+    def _create_optimizer(self, label, args, local_optimizer_config, parameters):
         betas = BETAS_DEFAULT
         epsilon = EPSILON_DEFAULT
         weight_decay = WEIGHT_DECAY_DEFAULT
@@ -298,12 +304,48 @@ class EveryDreamOptimizer():
                 amsgrad=False,
             )
 
-        log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr)
+        log_optimizer(label, optimizer, betas, epsilon, weight_decay, curr_lr)
         return optimizer
 
-def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
+    def _apply_text_encoder_freeze(self, text_encoder) -> chain[Any]:
+        parameters = itertools.chain([])
+
+        if self.te_freeze_config.get('freeze_embeddings', False):
+            # freeze embeddings
+            print(" ❄️ freezing embeddings")
+        else:
+            parameters = itertools.chain(parameters, text_encoder.text_model.embeddings.parameters())
+
+        freeze_front_n_layers = self.te_freeze_config.get('freeze_front_n_layers', None)
+        if freeze_front_n_layers is None:
+            parameters = itertools.chain(parameters, text_encoder.text_model.encoder.layers.parameters())
+        else:
+            # freeze the specified CLIP text encoder layers
+            layers = text_encoder.text_model.encoder.layers
+            print(f" ❄️ freezing text encoder layers 0-{len(layers[:freeze_front_n_layers])} of {len(layers)}")
+            parameters = itertools.chain(parameters, layers[freeze_front_n_layers:].parameters())
+
+        if self.te_freeze_config.get('freeze_final_layer_norm', False):
+            # instead of freezing the final layer norm parameters, we simply do not return them
+            print(" ❄️ freezing final layer norm")
+        else:
+            parameters = itertools.chain(parameters, text_encoder.text_model.final_layer_norm.parameters())
+
+        return parameters
+
+
+def log_optimizer(label: str, optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr):
     """
     logs the optimizer settings
     """
-    logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
+    all_params = sum([g['params'] for g in optimizer.param_groups], [])
+    frozen_parameter_count = len([p for p in all_params if not p.requires_grad])
+    total_parameter_count = len(all_params)
+    if frozen_parameter_count > 0:
+        param_info = f"({total_parameter_count} parameters, {frozen_parameter_count} frozen)"
+    else:
+        param_info = f"({total_parameter_count} parameters)"
+
+    logging.info(f"{Fore.CYAN} * {label} optimizer: {optimizer.__class__.__name__} {param_info} *{Style.RESET_ALL}")
     logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")

train.py
@@ -543,7 +543,11 @@ def main(args):
 
     epoch_len = math.ceil(len(train_batch) / args.batch_size)
 
-    ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)
+    ed_optimizer = EveryDreamOptimizer(args,
+                                       optimizer_config,
+                                       text_encoder,
+                                       unet,
+                                       epoch_len)
 
     log_args(log_writer, args)
 
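
Worth noting about the call above: EveryDreamOptimizer now receives the text_encoder and unet modules themselves, and the text encoder freezing works by omission: whatever _apply_text_encoder_freeze leaves out of the returned chain is simply never handed to the optimizer, rather than having requires_grad flipped. Below is a rough standalone sketch of that pattern, with a randomly initialized transformers CLIPTextModel standing in for the trainer's real model and plain AdamW standing in for the configurable optimizer setup:

    # "Freeze by omission": only the sub-modules meant to be trained are chained
    # together and passed to the optimizer; everything else never reaches it.
    import itertools
    import torch
    from transformers import CLIPTextConfig, CLIPTextModel

    text_encoder = CLIPTextModel(CLIPTextConfig())  # random 12-layer stand-in model
    freeze_front_n_layers = -2                      # train only the last 2 encoder layers

    layers = text_encoder.text_model.encoder.layers
    trainable_params = itertools.chain(
        layers[freeze_front_n_layers:].parameters(),            # last 2 transformer layers
        text_encoder.text_model.final_layer_norm.parameters(),  # keep final layer norm trainable
    )
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-6)

    in_optimizer = sum(p.numel() for g in optimizer.param_groups for p in g["params"])
    total = sum(p.numel() for p in text_encoder.parameters())
    print(f"optimizing {in_optimizer:,} of {total:,} text encoder parameters")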