optimizer split works: separate optimizers and LR schedulers for the text encoder and the unet, managed by EveryDreamOptimizer
This commit is contained in:
parent
970065c206
commit
4e81a0eb55
@@ -7,6 +7,7 @@ from torch.cuda.amp import autocast, GradScaler
 from diffusers.optimization import get_scheduler
 from colorama import Fore, Style
+import pprint

 BETAS_DEFAULT = [0.9, 0.999]
 EPSILON_DEFAULT = 1e-8
@@ -18,13 +19,15 @@ OPTIMIZER_UNET_STATE_FILENAME = "optimizer_unet.pt"
 class EveryDreamOptimizer():
     """
     Wrapper to manage optimizers
-    resume_ckpt_path: path to resume checkpoint, will load state files if they exist
-    optimizer_config: config for the optimizer
-    text_encoder: text encoder model
-    unet: unet model
+    resume_ckpt_path: path to resume checkpoint, will try to load state (.pt) files if they exist
+    optimizer_config: config for the optimizers
+    text_encoder: text encoder model parameters
+    unet: unet model parameters
     """
     def __init__(self, args, optimizer_config, text_encoder_params, unet_params, epoch_len):
-        print(f"\noptimizer_config: \n{optimizer_config}\n")
+        del optimizer_config["doc"]
+        print(f"\noptimizer_config:")
+        pprint.pprint(optimizer_config)
         self.grad_accum = args.grad_accum
         self.clip_grad_norm = args.clip_grad_norm
         self.text_encoder_params = text_encoder_params
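
Note: the new `del optimizer_config["doc"]` plus `pprint` implies the optimizer JSON carries a human-readable "doc" key that is stripped before printing. A hypothetical sketch of the config shape, inferred only from the two keys this diff actually touches ("doc" here, "lr_scheduler" in the scheduler hunk below); the real file likely has more fields:

    # Hypothetical config sketch -- only "doc" and "lr_scheduler" are confirmed by this diff.
    optimizer_config = {
        "doc": "free-form notes for humans; deleted before printing",
        "lr_scheduler": "constant",  # any name accepted by diffusers' get_scheduler
    }
    del optimizer_config["doc"]      # raises KeyError if "doc" is absent -- assumes the key is always present
    pprint.pprint(optimizer_config)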
@@ -57,21 +60,20 @@ class EveryDreamOptimizer():
     def step(self, loss, step, global_step):
         self.scaler.scale(loss).backward()
-        self.optimizer_te.step()
-        self.optimizer_unet.step()

         if self.clip_grad_norm is not None:
             torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
             torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
         if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
-            self.scaler.step(self.optimizer_te)
             self.scaler.step(self.optimizer_unet)
+            self.scaler.step(self.optimizer_te)

             self.scaler.update()
             self._zero_grad(set_to_none=True)

         self.lr_scheduler_unet.step()
         self.lr_scheduler_te.step()
-        self.update_grad_scaler(global_step)
+        self._update_grad_scaler(global_step)

     def _zero_grad(self, set_to_none=False):
         self.optimizer_te.zero_grad(set_to_none=set_to_none)
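
Note: in the new `step`, `clip_grad_norm_` runs on gradients that are still multiplied by the AMP loss scale, so the effective clip threshold varies with the current scale. The torch.cuda.amp docs recommend unscaling first; a minimal sketch of that pattern for one optimizer (names mirror the wrapper's members, this is not a drop-in patch from the commit):

    # Unscale-then-clip pattern from the torch.cuda.amp docs (sketch, not the diff's code).
    self.scaler.scale(loss).backward()
    self.scaler.unscale_(self.optimizer_unet)        # grads now in true (unscaled) units
    torch.nn.utils.clip_grad_norm_(self.unet_params, max_norm=self.clip_grad_norm)
    self.scaler.step(self.optimizer_unet)            # step is skipped automatically on inf/NaN grads
    self.scaler.update()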
@@ -83,7 +85,7 @@ class EveryDreamOptimizer():
     def get_unet_lr(self):
         return self.optimizer_unet.param_groups[0]['lr']

-    def get_te_lr(self):
+    def get_textenc_lr(self):
         return self.optimizer_te.param_groups[0]['lr']

     def save(self, ckpt_path: str):
@@ -93,6 +95,17 @@ class EveryDreamOptimizer():
         self._save_optimizer(self.optimizer_te, os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME))
         self._save_optimizer(self.optimizer_unet, os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME))

+    def load(self, ckpt_path: str):
+        """
+        Loads the optimizer states from path
+        """
+        te_optimizer_state_path = os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME)
+        unet_optimizer_state_path = os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME)
+        if os.path.exists(te_optimizer_state_path):
+            self._load_optimizer(self.optimizer_te, te_optimizer_state_path)
+        if os.path.exists(unet_optimizer_state_path):
+            self._load_optimizer(self.optimizer_unet, unet_optimizer_state_path)
+
     def create_optimizers(self, args, global_optimizer_config, text_encoder_params, unet_params):
         """
         creates optimizers from config and args for unet and text encoder
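
The save/load pair now round-trips two state files per checkpoint directory. A short usage sketch (directory name illustrative):

    ed_optimizer.save("ckpts/last")   # writes optimizer_te.pt and optimizer_unet.pt
    ed_optimizer.load("ckpts/last")   # restores each optimizer's state only if its .pt file exists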
@@ -145,6 +158,11 @@ class EveryDreamOptimizer():
         assert lr_scheduler_type_unet is not None, "lr_scheduler must be specified in optimizer config"
         lr_scheduler_type_te = optimizer_config.get("lr_scheduler", lr_scheduler_type_unet)

+        if args.lr_decay_steps is None or args.lr_decay_steps < 1:
+            args.lr_decay_steps = int(self.epoch_len * args.max_epochs * 1.5)
+
+        lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
+
         self.lr_scheduler_te = get_scheduler(
             lr_scheduler_type_te,
             optimizer=self.optimizer_te,
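
This defaulting logic moves here from create_optimizer (removed in a later hunk) so both schedulers see the same decay and warmup values. Worked example with illustrative numbers, not taken from the diff: with epoch_len=200 and max_epochs=20, lr_decay_steps defaults to int(200 * 20 * 1.5) = 6000 and lr_warmup_steps to int(6000 / 50) = 120.

    # Worked defaults (illustrative numbers):
    epoch_len, max_epochs = 200, 20
    lr_decay_steps = int(epoch_len * max_epochs * 1.5)   # 6000
    lr_warmup_steps = int(lr_decay_steps / 50)           # 120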
@@ -161,27 +179,27 @@ class EveryDreamOptimizer():

         return self.lr_scheduler_te, self.lr_scheduler_unet

-    def update_grad_scaler(self, global_step):
+    def _update_grad_scaler(self, global_step):
         if global_step == 500:
             factor = 1.8
             self.scaler.set_growth_factor(factor)
             self.scaler.set_backoff_factor(1/factor)
-            self.scaler.set_growth_interval(50)
+            self.scaler.set_growth_interval(100)
         if global_step == 1000:
             factor = 1.6
             self.scaler.set_growth_factor(factor)
             self.scaler.set_backoff_factor(1/factor)
-            self.scaler.set_growth_interval(50)
+            self.scaler.set_growth_interval(200)
         if global_step == 2000:
             factor = 1.3
             self.scaler.set_growth_factor(factor)
             self.scaler.set_backoff_factor(1/factor)
-            self.scaler.set_growth_interval(100)
+            self.scaler.set_growth_interval(1000)
         if global_step == 4000:
             factor = 1.15
             self.scaler.set_growth_factor(factor)
             self.scaler.set_backoff_factor(1/factor)
-            self.scaler.set_growth_interval(100)
+            self.scaler.set_growth_interval(2000)

     @staticmethod
     def _save_optimizer(optimizer, path: str):
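
For context on what these setters tune: torch.cuda.amp.GradScaler multiplies its loss scale by growth_factor after growth_interval consecutive steps without inf/NaN gradients, and multiplies it by backoff_factor whenever an overflow occurs. The revised schedule therefore grows the scale aggressively early (x1.8 every 100 clean steps after step 500) and nearly freezes it late (x1.15 every 2000 clean steps after step 4000). The same knobs can be set at construction time; a sketch, whereas the diff adjusts them mid-run:

    from torch.cuda.amp import GradScaler
    scaler = GradScaler(growth_factor=1.8, backoff_factor=1/1.8, growth_interval=100)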
@@ -191,11 +209,18 @@ class EveryDreamOptimizer():
         torch.save(optimizer.state_dict(), path)

     @staticmethod
-    def load_optimizer_state(optimizer: torch.optim.Optimizer, path: str):
+    def _load_optimizer(optimizer: torch.optim.Optimizer, path: str):
         """
         Loads the optimizer state to an Optimizer object
         optimizer: torch.optim.Optimizer
         path: .pt file
         """
-        optimizer.load_state_dict(torch.load(path))
+        try:
+            optimizer.load_state_dict(torch.load(path))
+            logging.info(f" Loaded optimizer state from {path}")
+        except Exception as e:
+            logging.warning(f"{Fore.LIGHTYELLOW_EX}**Failed to load optimizer state from {path}, optimizer state will not be loaded, \n * Exception: {e}{Style.RESET_ALL}")
+            pass

     def create_optimizer(self, args, local_optimizer_config, parameters):
         print(f"Creating optimizer from {local_optimizer_config}")
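
One subtlety in `_load_optimizer`: `torch.load(path)` restores tensors to the devices they were saved from. A hedged variant (an assumption, not what the diff does) pins loading to CPU and relies on `Optimizer.load_state_dict` to cast state tensors to each parameter's device:

    state = torch.load(path, map_location="cpu")  # avoids device-mismatch when resuming on different hardware
    optimizer.load_state_dict(state)              # state is cast to match each param's device/dtype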
@@ -254,11 +279,6 @@ class EveryDreamOptimizer():
             weight_decay=weight_decay,
             amsgrad=False,
         )

-        if args.lr_decay_steps is None or args.lr_decay_steps < 1:
-            args.lr_decay_steps = int(self.epoch_len * args.max_epochs * 1.5)
-
-        lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
-
         log_optimizer(optimizer, betas, epsilon, weight_decay, curr_lr, curr_text_encoder_lr)
         return optimizer
train.py (28 changed lines)
@@ -382,7 +382,7 @@ def main(args):
     os.makedirs(log_folder)

 @torch.no_grad()
-def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False):
+def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name, save_full_precision=False, save_optimizer_flag=False):
     """
     Save the model to disk
     """
@@ -421,9 +421,8 @@ def main(args):
         shutil.copyfile(yaml_name, yaml_save_path)

     if save_optimizer_flag:
-        optimizer_path = os.path.join(save_path, "optimizer.pt")
         logging.info(f" Saving optimizer state to {save_path}")
-        save_optimizer(optimizer, optimizer_path)
+        ed_optimizer.save(save_path)

     optimizer_state_path = None
     try:
@@ -545,7 +544,7 @@ def main(args):
     epoch_len = math.ceil(len(train_batch) / args.batch_size)

     ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder.parameters(), unet.parameters(), epoch_len)
-    exit()

     log_args(log_writer, args)

     sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
@@ -675,6 +674,7 @@ def main(args):
             del noise, latents, cuda_caption

             with autocast(enabled=args.amp):
+                #print(f"types: {type(noisy_latents)} {type(timesteps)} {type(encoder_hidden_states)}")
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

         return model_pred, target
@@ -735,18 +735,7 @@ def main(args):
                 loss_scale = batch["runt_size"] / args.batch_size
                 loss = loss * loss_scale

-            ed_optimizer.step(step, global_step)
-
-            # if args.clip_grad_norm is not None:
-            #     if not args.disable_unet_training:
-            #         torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
-            #     if not args.disable_textenc_training:
-            #         torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
-
-            #if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
-                #ed_optimizers.step(step, global_step)
-                #scaler.update()
-                #optimizer.zero_grad(set_to_none=True)
+            ed_optimizer.step(loss, step, global_step)

             loss_step = loss.detach().item()
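
With the wrapper owning gradient accumulation, the loop calls ed_optimizer.step once per micro-batch and the wrapper decides when the optimizers actually step. Worked example of the condition ((global_step + 1) % grad_accum == 0): with grad_accum=4, real steps fire at global_step 3, 7, 11, ..., i.e. every fourth micro-batch, plus the final step of each epoch.

    # Which micro-batches trigger a real optimizer step (illustrative):
    grad_accum = 4
    stepped = [gs for gs in range(12) if (gs + 1) % grad_accum == 0]   # [3, 7, 11]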
@@ -789,16 +778,15 @@ def main(args):
                     last_epoch_saved_time = time.time()
                     logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
                     save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
+                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)

                 if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
                     logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
                     save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
+                    __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)

                 del batch
                 global_step += 1
-                #update_grad_scaler(scaler, global_step, epoch, step) if args.amp else None
                 # end of step

             steps_pbar.close()
@@ -834,7 +822,7 @@ def main(args):
     except Exception as ex:
         logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
         save_path = os.path.join(f"{log_folder}/ckpts/errored-{args.project_name}-ep{epoch:02}-gs{global_step:05}")
-        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
+        __save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer)
         raise ex

     logging.info(f"{Fore.LIGHTWHITE_EX} ***************************{Style.RESET_ALL}")