refactor optimizer to split te and unet
This commit is contained in:
parent
f0449c64e7
commit
3639e36135
|
@ -1,17 +1,32 @@
|
|||
{
|
||||
"doc": {
|
||||
"unet": "unet config",
|
||||
"text_encoder": "text encoder config, if properties are null copies from unet config",
|
||||
"text_encoder_lr_scale": "if LR not set on text encoder, sets the Lr to a multiple of the Unet LR. for example, if unet `lr` is 2e-6 and `text_encoder_lr_scale` is 0.5, the text encoder's LR will be set to `1e-6`.",
|
||||
"-----------------": "-----------------",
|
||||
"optimizer": "adamw, adamw8bit, lion",
|
||||
"optimizer_desc": "'adamw' in standard 32bit, 'adamw8bit' is bitsandbytes, 'lion' is lucidrains",
|
||||
"lr": "learning rate, if null wil use CLI or main JSON config value",
|
||||
"lr": "learning rate, if null will use CLI or main JSON config value",
|
||||
"lr_scheduler": "overrides global lr scheduler from main config",
|
||||
"betas": "exponential decay rates for the moment estimates",
|
||||
"epsilon": "value added to denominator for numerical stability, unused for lion",
|
||||
"weight_decay": "weight decay (L2 penalty)",
|
||||
"text_encoder_lr_scale": "scale the text encoder LR relative to the Unet LR. for example, if `lr` is 2e-6 and `text_encoder_lr_scale` is 0.5, the text encoder's LR will be set to `1e-6`."
|
||||
"weight_decay": "weight decay (L2 penalty)"
|
||||
},
|
||||
"text_encoder_lr_scale": 0.5,
|
||||
"unet": {
|
||||
"optimizer": "adamw8bit",
|
||||
"lr": 1e-6,
|
||||
"lr_scheduler": null,
|
||||
"betas": [0.9, 0.999],
|
||||
"epsilon": 1e-8,
|
||||
"weight_decay": 0.010,
|
||||
"text_encoder_lr_scale": 0.50
|
||||
"weight_decay": 0.010
|
||||
},
|
||||
"text_encoder": {
|
||||
"optimizer": null,
|
||||
"lr": null,
|
||||
"lr_scheduler": null,
|
||||
"betas": null,
|
||||
"epsilon": null,
|
||||
"weight_decay": null
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,257 @@
|
|||
import logging
|
||||
import itertools
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
from colorama import Fore, Style
|
||||
|
||||
BETAS_DEFAULT = [0.9, 0.999]
|
||||
EPSILON_DEFAULT = 1e-8
|
||||
WEIGHT_DECAY_DEFAULT = 0.01
|
||||
LR_DEFAULT = 1e-6
|
||||
OPTIMIZER_TE_STATE_FILENAME = "optimizer_te.pt"
|
||||
OPTIMIZER_UNET_STATE_FILENAME = "optimizer_unet.pt"
|
||||
|
||||
class EveryDreamOptimizer():
|
||||
"""
|
||||
Wrapper to manage optimizers
|
||||
resume_ckpt_path: path to resume checkpoint, will load state files if they exist
|
||||
optimizer_config: config for the optimizer
|
||||
text_encoder: text encoder model
|
||||
unet: unet model
|
||||
"""
|
||||
def __init__(self, args, optimizer_config, text_encoder_params, unet_params):
|
||||
self.grad_accum = args.grad_accum
|
||||
self.clip_grad_norm = args.clip_grad_norm
|
||||
self.text_encoder_params = text_encoder_params
|
||||
self.unet_params = unet_params
|
||||
|
||||
self.optimizer_te, self.optimizer_unet = self.create_optimizers(args, optimizer_config, text_encoder_params, unet_params)
|
||||
self.lr_scheduler_te, self.lr_scheduler_unet = self.create_lr_schedulers(args, optimizer_config)
|
||||
|
||||
self.unet_config = optimizer_config.get("unet", {})
|
||||
if args.lr is not None:
|
||||
self.unet_config["lr"] = args.lr
|
||||
self.te_config = optimizer_config.get("text_encoder", {})
|
||||
if self.te_config.get("lr", None) is None:
|
||||
self.te_config["lr"] = self.unet_config["lr"]
|
||||
te_scale = self.optimizer_config.get("text_encoder_lr_scale", None)
|
||||
if te_scale is not None:
|
||||
self.te_config["lr"] = self.unet_config["lr"] * te_scale
|
||||
|
||||
optimizer_te_state_path = os.path.join(args.resume_ckpt, OPTIMIZER_TE_STATE_FILENAME)
|
||||
optimizer_unet_state_path = os.path.join(args.resume_ckpt, OPTIMIZER_UNET_STATE_FILENAME)
|
||||
if os.path.exists(optimizer_te_state_path):
|
||||
logging.info(f"Loading text encoder optimizer state from {optimizer_te_state_path}")
|
||||
self.load_optimizer_state(self.optimizer_te, optimizer_te_state_path)
|
||||
if os.path.exists(optimizer_unet_state_path):
|
||||
logging.info(f"Loading unet optimizer state from {optimizer_unet_state_path}")
|
||||
self.load_optimizer_state(self.optimizer_unet, optimizer_unet_state_path)
|
||||
|
||||
self.scaler = GradScaler(
|
||||
enabled=args.amp,
|
||||
init_scale=2**17.5,
|
||||
growth_factor=2,
|
||||
backoff_factor=1.0/2,
|
||||
growth_interval=25,
|
||||
)
|
||||
|
||||
logging.info(f" Grad scaler enabled: {self.scaler.is_enabled()} (amp mode)")
|
||||
|
||||
def step(self, loss, step, global_step):
|
||||
self.scaler.scale(loss).backward()
|
||||
self.optimizer_te.step()
|
||||
self.optimizer_unet.step()
|
||||
|
||||
if self.clip_grad_norm is not None:
|
||||
if not args.disable_unet_training:
|
||||
torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
|
||||
if not args.disable_textenc_training:
|
||||
torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
|
||||
if ((global_step + 1) % self.grad_accum == 0) or (step == epoch_len - 1):
|
||||
self.scaler.step(self.optimizer_te)
|
||||
self.scaler.step(self.optimizer_unet)
|
||||
self.scaler.update()
|
||||
self._zero_grad(set_to_none=True)
|
||||
|
||||
self.lr_scheduler.step()
|
||||
|
||||
self.optimizer_unet.step()
|
||||
self.update_grad_scaler(global_step)
|
||||
|
||||
def _zero_grad(self, set_to_none=False):
|
||||
self.optimizer_te.zero_grad(set_to_none=set_to_none)
|
||||
self.optimizer_unet.zero_grad(set_to_none=set_to_none)
|
||||
|
||||
def get_scale(self):
|
||||
return self.scaler.get_scale()
|
||||
|
||||
def get_unet_lr(self):
|
||||
return self.optimizer_unet.param_groups[0]['lr']
|
||||
|
||||
def get_te_lr(self):
|
||||
return self.optimizer_te.param_groups[0]['lr']
|
||||
|
||||
def save(self, ckpt_path: str):
|
||||
"""
|
||||
Saves the optimizer states to path
|
||||
"""
|
||||
self._save_optimizer(self.optimizer_te, os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME))
|
||||
self._save_optimizer(self.optimizer_unet, os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME))
|
||||
|
||||
def create_optimizers(self, args, global_optimizer_config, text_encoder, unet):
|
||||
"""
|
||||
creates optimizers from config and argsfor unet and text encoder
|
||||
returns (optimizer_te, optimizer_unet)
|
||||
"""
|
||||
if args.disable_textenc_training:
|
||||
optimizer_te = create_null_optimizer()
|
||||
else:
|
||||
optimizer_te = self.create_optimizer(global_optimizer_config.get("text_encoder"), text_encoder)
|
||||
if args.disable_unet_training:
|
||||
optimizer_unet = create_null_optimizer()
|
||||
else:
|
||||
optimizer_unet = self.create_optimizer(global_optimizer_config, unet)
|
||||
|
||||
return optimizer_te, optimizer_unet
|
||||
|
||||
def create_lr_schedulers(self, args, optimizer_config):
|
||||
lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
|
||||
lr_scheduler_type_te = optimizer_config.get("lr_scheduler", self.unet_config.lr_scheduler)
|
||||
self.lr_scheduler_te = get_scheduler(
|
||||
lr_scheduler_type_te,
|
||||
optimizer=self.optimizer_te,
|
||||
num_warmup_steps=lr_warmup_steps,
|
||||
num_training_steps=args.lr_decay_steps,
|
||||
)
|
||||
self.lr_scheduler_unet = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=self.optimizer_unet,
|
||||
num_warmup_steps=lr_warmup_steps,
|
||||
num_training_steps=args.lr_decay_steps,
|
||||
)
|
||||
return self.lr_scheduler_te, self.lr_scheduler_unet
|
||||
|
||||
def update_grad_scaler(self, global_step):
|
||||
if global_step == 500:
|
||||
factor = 1.8
|
||||
self.scaler.set_growth_factor(factor)
|
||||
self.scaler.set_backoff_factor(1/factor)
|
||||
self.scaler.set_growth_interval(50)
|
||||
if global_step == 1000:
|
||||
factor = 1.6
|
||||
self.scaler.set_growth_factor(factor)
|
||||
self.scaler.set_backoff_factor(1/factor)
|
||||
self.scaler.set_growth_interval(50)
|
||||
if global_step == 2000:
|
||||
factor = 1.3
|
||||
self.scaler.set_growth_factor(factor)
|
||||
self.scaler.set_backoff_factor(1/factor)
|
||||
self.scaler.set_growth_interval(100)
|
||||
if global_step == 4000:
|
||||
factor = 1.15
|
||||
self.scaler.set_growth_factor(factor)
|
||||
self.scaler.set_backoff_factor(1/factor)
|
||||
self.scaler.set_growth_interval(100)
|
||||
|
||||
@staticmethod
|
||||
def _save_optimizer(optimizer, path: str):
|
||||
"""
|
||||
Saves the optimizer state to specific path/filename
|
||||
"""
|
||||
torch.save(optimizer.state_dict(), path)
|
||||
|
||||
@staticmethod
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, path: str):
|
||||
"""
|
||||
Loads the optimizer state to an Optimizer object
|
||||
"""
|
||||
optimizer.load_state_dict(torch.load(path))
|
||||
|
||||
@staticmethod
|
||||
def create_optimizer(args, local_optimizer_config, parameters):
|
||||
betas = BETAS_DEFAULT
|
||||
epsilon = EPSILON_DEFAULT
|
||||
weight_decay = WEIGHT_DECAY_DEFAULT
|
||||
opt_class = None
|
||||
optimizer = None
|
||||
|
||||
default_lr = 1e-6
|
||||
curr_lr = args.lr
|
||||
text_encoder_lr_scale = 1.0
|
||||
|
||||
if local_optimizer_config is not None:
|
||||
betas = local_optimizer_config["betas"]
|
||||
epsilon = local_optimizer_config["epsilon"]
|
||||
weight_decay = local_optimizer_config["weight_decay"]
|
||||
optimizer_name = local_optimizer_config["optimizer"]
|
||||
curr_lr = local_optimizer_config.get("lr", curr_lr)
|
||||
if args.lr is not None:
|
||||
curr_lr = args.lr
|
||||
logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
|
||||
|
||||
text_encoder_lr_scale = local_optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale)
|
||||
if text_encoder_lr_scale != 1.0:
|
||||
logging.info(f" * Using text encoder LR scale {text_encoder_lr_scale}")
|
||||
|
||||
if curr_lr is None:
|
||||
curr_lr = default_lr
|
||||
logging.warning(f"No LR setting found, defaulting to {default_lr}")
|
||||
|
||||
curr_text_encoder_lr = curr_lr * text_encoder_lr_scale
|
||||
|
||||
if optimizer_name:
|
||||
if optimizer_name == "lion":
|
||||
from lion_pytorch import Lion
|
||||
opt_class = Lion
|
||||
optimizer = opt_class(
|
||||
itertools.chain(parameters),
|
||||
lr=curr_lr,
|
||||
betas=(betas[0], betas[1]),
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
elif optimizer_name in ["adamw"]:
|
||||
opt_class = torch.optim.AdamW
|
||||
else:
|
||||
import bitsandbytes as bnb
|
||||
opt_class = bnb.optim.AdamW8bit
|
||||
|
||||
if not optimizer:
|
||||
optimizer = opt_class(
|
||||
itertools.chain(parameters),
|
||||
lr=curr_lr,
|
||||
betas=(betas[0], betas[1]),
|
||||
eps=epsilon,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=False,
|
||||
)
|
||||
|
||||
if args.lr_decay_steps is None or args.lr_decay_steps < 1:
|
||||
args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.5)
|
||||
|
||||
lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=lr_warmup_steps,
|
||||
num_training_steps=args.lr_decay_steps,
|
||||
)
|
||||
|
||||
|
||||
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
|
||||
return optimizer
|
||||
|
||||
|
||||
def create_null_optimizer():
|
||||
return torch.optim.AdamW([torch.zeros(1)], lr=0)
|
||||
|
||||
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, lr, model_name):
|
||||
"""
|
||||
logs the optimizer settings
|
||||
"""
|
||||
logging.info(f"{Fore.CYAN} * Optimizer {model_name}: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
|
||||
logging.info(f"{Fore.CYAN} lr: {lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
|
196
train.py
196
train.py
|
@ -29,7 +29,7 @@ import traceback
|
|||
import shutil
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from colorama import Fore, Style
|
||||
import numpy as np
|
||||
|
@ -60,6 +60,7 @@ from utils.huggingface_downloader import try_download_model_from_hf
|
|||
from utils.convert_diff_to_ckpt import convert as converter
|
||||
from utils.isolate_rng import isolate_rng
|
||||
from utils.check_git import check_git
|
||||
from optimizer.optimizers import EveryDreamOptimizer
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from utils.gpu import GPU
|
||||
|
@ -131,24 +132,17 @@ def setup_local_logger(args):
|
|||
|
||||
return datetimestamp
|
||||
|
||||
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay, unet_lr, text_encoder_lr):
|
||||
"""
|
||||
logs the optimizer settings
|
||||
"""
|
||||
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
|
||||
logging.info(f"{Fore.CYAN} unet lr: {unet_lr}, text encoder lr: {text_encoder_lr}, betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
|
||||
# def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||
# """
|
||||
# Saves the optimizer state
|
||||
# """
|
||||
# torch.save(optimizer.state_dict(), path)
|
||||
|
||||
def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||
"""
|
||||
Saves the optimizer state
|
||||
"""
|
||||
torch.save(optimizer.state_dict(), path)
|
||||
|
||||
def load_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||
"""
|
||||
Loads the optimizer state
|
||||
"""
|
||||
optimizer.load_state_dict(torch.load(path))
|
||||
# def load_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||
# """
|
||||
# Loads the optimizer state
|
||||
# """
|
||||
# optimizer.load_state_dict(torch.load(path))
|
||||
|
||||
def get_gpu_memory(nvsmi):
|
||||
"""
|
||||
|
@ -284,28 +278,6 @@ def setup_args(args):
|
|||
|
||||
return args
|
||||
|
||||
def update_grad_scaler(scaler: GradScaler, global_step, epoch, step):
|
||||
if global_step == 500:
|
||||
factor = 1.8
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(50)
|
||||
if global_step == 1000:
|
||||
factor = 1.6
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(50)
|
||||
if global_step == 2000:
|
||||
factor = 1.3
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
if global_step == 4000:
|
||||
factor = 1.15
|
||||
scaler.set_growth_factor(factor)
|
||||
scaler.set_backoff_factor(1/factor)
|
||||
scaler.set_growth_interval(100)
|
||||
|
||||
|
||||
def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem], batch_size) -> None:
|
||||
undersized_items = [item for item in items if item.is_undersized]
|
||||
|
@ -453,7 +425,6 @@ def main(args):
|
|||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||
shutil.copyfile(yaml_name, yaml_save_path)
|
||||
|
||||
|
||||
if save_optimizer_flag:
|
||||
optimizer_path = os.path.join(save_path, "optimizer.pt")
|
||||
logging.info(f" Saving optimizer state to {save_path}")
|
||||
|
@ -531,8 +502,6 @@ def main(args):
|
|||
project=args.project_name,
|
||||
config={"main_cfg": vars(args), "optimizer_cfg": optimizer_config},
|
||||
name=args.run_name,
|
||||
#sync_tensorboard=True, # broken?
|
||||
#dir=log_folder, # only for save, just duplicates the TB log to /{log_folder}/wandb ...
|
||||
)
|
||||
try:
|
||||
if webbrowser.get():
|
||||
|
@ -545,84 +514,6 @@ def main(args):
|
|||
comment=args.run_name if args.run_name is not None else log_time,
|
||||
)
|
||||
|
||||
betas = [0.9, 0.999]
|
||||
epsilon = 1e-8
|
||||
weight_decay = 0.01
|
||||
opt_class = None
|
||||
optimizer = None
|
||||
|
||||
default_lr = 1e-6
|
||||
curr_lr = args.lr
|
||||
text_encoder_lr_scale = 1.0
|
||||
|
||||
if optimizer_config is not None:
|
||||
betas = optimizer_config["betas"]
|
||||
epsilon = optimizer_config["epsilon"]
|
||||
weight_decay = optimizer_config["weight_decay"]
|
||||
optimizer_name = optimizer_config["optimizer"]
|
||||
curr_lr = optimizer_config.get("lr", curr_lr)
|
||||
if args.lr is not None:
|
||||
curr_lr = args.lr
|
||||
logging.info(f"Overriding LR from optimizer config with main config/cli LR setting: {curr_lr}")
|
||||
|
||||
text_encoder_lr_scale = optimizer_config.get("text_encoder_lr_scale", text_encoder_lr_scale)
|
||||
if text_encoder_lr_scale != 1.0:
|
||||
logging.info(f" * Using text encoder LR scale {text_encoder_lr_scale}")
|
||||
|
||||
logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
|
||||
|
||||
if curr_lr is None:
|
||||
curr_lr = default_lr
|
||||
logging.warning(f"No LR setting found, defaulting to {default_lr}")
|
||||
|
||||
curr_text_encoder_lr = curr_lr * text_encoder_lr_scale
|
||||
|
||||
if args.disable_textenc_training:
|
||||
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(unet.parameters())
|
||||
elif args.disable_unet_training:
|
||||
logging.info(f"{Fore.CYAN} * Training Text Encoder Only *{Style.RESET_ALL}")
|
||||
if text_encoder_lr_scale != 1:
|
||||
logging.warning(f"{Fore.YELLOW} * Ignoring text_encoder_lr_scale {text_encoder_lr_scale} and using the "
|
||||
f"Unet LR {curr_lr} for the text encoder instead *{Style.RESET_ALL}")
|
||||
params_to_train = itertools.chain(text_encoder.parameters())
|
||||
else:
|
||||
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
|
||||
params_to_train = [{'params': unet.parameters()},
|
||||
{'params': text_encoder.parameters(), 'lr': curr_text_encoder_lr}]
|
||||
|
||||
if optimizer_name:
|
||||
if optimizer_name == "lion":
|
||||
from lion_pytorch import Lion
|
||||
opt_class = Lion
|
||||
optimizer = opt_class(
|
||||
itertools.chain(params_to_train),
|
||||
lr=curr_lr,
|
||||
betas=(betas[0], betas[1]),
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
elif optimizer_name in ["adamw"]:
|
||||
opt_class = torch.optim.AdamW
|
||||
else:
|
||||
import bitsandbytes as bnb
|
||||
opt_class = bnb.optim.AdamW8bit
|
||||
|
||||
if not optimizer:
|
||||
optimizer = opt_class(
|
||||
itertools.chain(params_to_train),
|
||||
lr=curr_lr,
|
||||
betas=(betas[0], betas[1]),
|
||||
eps=epsilon,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=False,
|
||||
)
|
||||
|
||||
if optimizer_state_path is not None:
|
||||
logging.info(f"Loading optimizer state from {optimizer_state_path}")
|
||||
load_optimizer(optimizer, optimizer_state_path)
|
||||
|
||||
log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
|
||||
|
||||
image_train_items = resolve_image_train_items(args)
|
||||
|
||||
validator = None
|
||||
|
@ -658,17 +549,7 @@ def main(args):
|
|||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
|
||||
if args.lr_decay_steps is None or args.lr_decay_steps < 1:
|
||||
args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.5)
|
||||
|
||||
lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=lr_warmup_steps,
|
||||
num_training_steps=args.lr_decay_steps,
|
||||
)
|
||||
ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters())
|
||||
|
||||
log_args(log_writer, args)
|
||||
|
||||
|
@ -742,15 +623,6 @@ def main(args):
|
|||
logging.info(f" {Fore.GREEN}batch_size: {Style.RESET_ALL}{Fore.LIGHTGREEN_EX}{args.batch_size}{Style.RESET_ALL}")
|
||||
logging.info(f" {Fore.GREEN}epoch_len: {Fore.LIGHTGREEN_EX}{epoch_len}{Style.RESET_ALL}")
|
||||
|
||||
scaler = GradScaler(
|
||||
enabled=args.amp,
|
||||
init_scale=2**17.5,
|
||||
growth_factor=2,
|
||||
backoff_factor=1.0/2,
|
||||
growth_interval=25,
|
||||
)
|
||||
logging.info(f" Grad scaler enabled: {scaler.is_enabled()} (amp mode)")
|
||||
|
||||
epoch_pbar = tqdm(range(args.max_epochs), position=0, leave=True, dynamic_ncols=True)
|
||||
epoch_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Epochs{Style.RESET_ALL}")
|
||||
epoch_times = []
|
||||
|
@ -868,20 +740,18 @@ def main(args):
|
|||
loss_scale = batch["runt_size"] / args.batch_size
|
||||
loss = loss * loss_scale
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
ed_optimizer.step(step, global_step)
|
||||
|
||||
if args.clip_grad_norm is not None:
|
||||
if not args.disable_unet_training:
|
||||
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
|
||||
if not args.disable_textenc_training:
|
||||
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
|
||||
# if args.clip_grad_norm is not None:
|
||||
# if not args.disable_unet_training:
|
||||
# torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
|
||||
# if not args.disable_textenc_training:
|
||||
# torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
|
||||
|
||||
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
lr_scheduler.step()
|
||||
#if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
|
||||
#ed_optimizers.step(step, global_step)
|
||||
#scaler.update()
|
||||
#optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
loss_step = loss.detach().item()
|
||||
|
||||
|
@ -895,23 +765,23 @@ def main(args):
|
|||
loss_epoch.append(loss_step)
|
||||
|
||||
if (global_step + 1) % args.log_step == 0:
|
||||
curr_lr = lr_scheduler.get_last_lr()[0]
|
||||
loss_local = sum(loss_log_step) / len(loss_log_step)
|
||||
lr_unet = ed_optimizer.get_unet_lr()
|
||||
lr_textenc = ed_optimizer.get_textenc_lr()
|
||||
loss_log_step = []
|
||||
logs = {"loss/log_step": loss_local, "lr": curr_lr, "img/s": images_per_sec}
|
||||
if args.disable_textenc_training or args.disable_unet_training or text_encoder_lr_scale == 1:
|
||||
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
|
||||
else:
|
||||
log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=curr_lr, global_step=global_step)
|
||||
curr_text_encoder_lr = lr_scheduler.get_last_lr()[1]
|
||||
log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=curr_text_encoder_lr, global_step=global_step)
|
||||
|
||||
log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=lr_unet, global_step=global_step)
|
||||
log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
|
||||
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
|
||||
|
||||
sum_img = sum(images_per_sec_log_step)
|
||||
avg = sum_img / len(images_per_sec_log_step)
|
||||
images_per_sec_log_step = []
|
||||
if args.amp:
|
||||
log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=scaler.get_scale(), global_step=global_step)
|
||||
log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=ed_optimizer.get_scale(), global_step=global_step)
|
||||
log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
|
||||
|
||||
logs = {"loss/log_step": loss_local, "lr_unet": lr_unet, "lr_te": lr_textenc, "img/s": images_per_sec}
|
||||
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -933,7 +803,7 @@ def main(args):
|
|||
|
||||
del batch
|
||||
global_step += 1
|
||||
update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
|
||||
#update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
|
||||
# end of step
|
||||
|
||||
steps_pbar.close()
|
||||
|
|
Loading…
Reference in New Issue