Merge pull request #209 from damian0815/feat_rolling_save
Feature: Rolling save ckpt
This commit is contained in:
commit
6c8d15daab
|
@ -0,0 +1,56 @@
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from plugins.plugins import BasePlugin
|
||||||
|
from train import save_model
|
||||||
|
|
||||||
|
EVERY_N_EPOCHS = 1 # how often to save. integers >= 1 save at the end of every nth epoch. floats < 1 subdivide the epoch evenly (eg 0.33 = 3 subdivisions)
|
||||||
|
|
||||||
|
class InterruptiblePlugin(BasePlugin):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
print("Interruptible plugin instantiated")
|
||||||
|
self.previous_save_path = None
|
||||||
|
self.every_n_epochs = EVERY_N_EPOCHS
|
||||||
|
|
||||||
|
def on_epoch_start(self, **kwargs):
|
||||||
|
epoch = kwargs['epoch']
|
||||||
|
epoch_length = kwargs['epoch_length']
|
||||||
|
self.steps_to_save_this_epoch = self._get_save_step_indices(epoch, epoch_length)
|
||||||
|
|
||||||
|
def on_step_end(self, **kwargs):
|
||||||
|
local_step = kwargs['local_step']
|
||||||
|
if local_step in self.steps_to_save_this_epoch:
|
||||||
|
global_step = kwargs['global_step']
|
||||||
|
epoch = kwargs['epoch']
|
||||||
|
project_name = kwargs['project_name']
|
||||||
|
log_folder = kwargs['log_folder']
|
||||||
|
ckpt_name = f"rolling-{project_name}-ep{epoch:02}-gs{global_step:05}"
|
||||||
|
save_path = os.path.join(log_folder, "ckpts", ckpt_name)
|
||||||
|
print(f"{type(self)} saving model to {save_path}")
|
||||||
|
save_model(save_path, global_step=global_step, ed_state=kwargs['ed_state'], save_ckpt_dir=None, yaml_name=None, save_ckpt=False, save_full_precision=True, save_optimizer_flag=True)
|
||||||
|
self._remove_previous()
|
||||||
|
self.previous_save_path = save_path
|
||||||
|
|
||||||
|
def on_training_end(self, **kwargs):
|
||||||
|
self._remove_previous()
|
||||||
|
|
||||||
|
def _remove_previous(self):
|
||||||
|
if self.previous_save_path is not None:
|
||||||
|
shutil.rmtree(self.previous_save_path, ignore_errors=True)
|
||||||
|
self.previous_save_path = None
|
||||||
|
|
||||||
|
def _get_save_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
|
||||||
|
if self.every_n_epochs >= 1:
|
||||||
|
if ((epoch+1) % self.every_n_epochs) == 0:
|
||||||
|
# last step only
|
||||||
|
return [epoch_length_steps-1]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
|
||||||
|
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
|
||||||
|
# validation happens after training:
|
||||||
|
# if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
|
||||||
|
validate_every_n_steps = epoch_length_steps / num_divisions
|
||||||
|
return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
|
|
@ -44,7 +44,7 @@ class Timer:
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
elapsed_time = time.time() - self.start
|
elapsed_time = time.time() - self.start
|
||||||
if elapsed_time > self.warn_seconds:
|
if elapsed_time > self.warn_seconds:
|
||||||
logging.warning(f'Execution of {self.label} took {elapsed_time} seconds which is longer than the limit of {self.limit} seconds')
|
logging.warning(f'Execution of {self.label} took {elapsed_time} seconds which is longer than the limit of {self.warn_seconds} seconds')
|
||||||
|
|
||||||
|
|
||||||
class PluginRunner:
|
class PluginRunner:
|
||||||
|
|
278
train.py
278
train.py
|
@ -27,6 +27,7 @@ import gc
|
||||||
import random
|
import random
|
||||||
import traceback
|
import traceback
|
||||||
import shutil
|
import shutil
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
@ -102,6 +103,109 @@ def convert_to_hf(ckpt_path):
|
||||||
is_sd1attn, yaml = get_attn_yaml(ckpt_path)
|
is_sd1attn, yaml = get_attn_yaml(ckpt_path)
|
||||||
return ckpt_path, is_sd1attn, yaml
|
return ckpt_path, is_sd1attn, yaml
|
||||||
|
|
||||||
|
class EveryDreamTrainingState:
|
||||||
|
def __init__(self,
|
||||||
|
optimizer: EveryDreamOptimizer,
|
||||||
|
train_batch: EveryDreamBatch,
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
text_encoder: CLIPTextModel,
|
||||||
|
tokenizer: CLIPTokenizer,
|
||||||
|
scheduler,
|
||||||
|
vae: AutoencoderKL,
|
||||||
|
unet_ema: Optional[UNet2DConditionModel],
|
||||||
|
text_encoder_ema: Optional[CLIPTextModel]
|
||||||
|
):
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.train_batch = train_batch
|
||||||
|
self.unet = unet
|
||||||
|
self.text_encoder = text_encoder
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.scheduler = scheduler
|
||||||
|
self.vae = vae
|
||||||
|
self.unet_ema = unet_ema
|
||||||
|
self.text_encoder_ema = text_encoder_ema
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def save_model(save_path, ed_state: EveryDreamTrainingState, global_step: int, save_ckpt_dir, yaml_name,
|
||||||
|
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
|
||||||
|
"""
|
||||||
|
Save the model to disk
|
||||||
|
"""
|
||||||
|
|
||||||
|
def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
|
||||||
|
nonlocal save_ckpt_dir
|
||||||
|
nonlocal save_full_precision
|
||||||
|
nonlocal yaml_name
|
||||||
|
|
||||||
|
if save_ckpt_dir is not None:
|
||||||
|
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
||||||
|
else:
|
||||||
|
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
|
||||||
|
save_ckpt_dir = os.curdir
|
||||||
|
|
||||||
|
half = not save_full_precision
|
||||||
|
|
||||||
|
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
||||||
|
converter(model_path=diffusers_model_path, checkpoint_path=sd_ckpt_full, half=half)
|
||||||
|
|
||||||
|
if yaml_name and yaml_name != "v1-inference.yaml":
|
||||||
|
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(diffusers_model_path))}.yaml"
|
||||||
|
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||||
|
shutil.copyfile(yaml_name, yaml_save_path)
|
||||||
|
|
||||||
|
|
||||||
|
if global_step is None or global_step == 0:
|
||||||
|
logging.warning(" No model to save, something likely blew up on startup, not saving")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if args.ema_decay_rate != None:
|
||||||
|
pipeline_ema = StableDiffusionPipeline(
|
||||||
|
vae=ed_state.vae,
|
||||||
|
text_encoder=ed_state.text_encoder_ema,
|
||||||
|
tokenizer=ed_state.tokenizer,
|
||||||
|
unet=ed_state.unet_ema,
|
||||||
|
scheduler=ed_state.scheduler,
|
||||||
|
safety_checker=None, # save vram
|
||||||
|
requires_safety_checker=None, # avoid nag
|
||||||
|
feature_extractor=None, # must be none of no safety checker
|
||||||
|
)
|
||||||
|
|
||||||
|
diffusers_model_path = save_path + "_ema"
|
||||||
|
logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}")
|
||||||
|
pipeline_ema.save_pretrained(diffusers_model_path)
|
||||||
|
|
||||||
|
if save_ckpt:
|
||||||
|
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
|
||||||
|
|
||||||
|
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
|
||||||
|
|
||||||
|
|
||||||
|
pipeline = StableDiffusionPipeline(
|
||||||
|
vae=ed_state.vae,
|
||||||
|
text_encoder=ed_state.text_encoder,
|
||||||
|
tokenizer=ed_state.tokenizer,
|
||||||
|
unet=ed_state.unet,
|
||||||
|
scheduler=ed_state.scheduler,
|
||||||
|
safety_checker=None, # save vram
|
||||||
|
requires_safety_checker=None, # avoid nag
|
||||||
|
feature_extractor=None, # must be none of no safety checker
|
||||||
|
)
|
||||||
|
|
||||||
|
diffusers_model_path = save_path
|
||||||
|
logging.info(f" * Saving diffusers model to {diffusers_model_path}")
|
||||||
|
pipeline.save_pretrained(diffusers_model_path)
|
||||||
|
|
||||||
|
if save_ckpt:
|
||||||
|
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
||||||
|
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
|
||||||
|
|
||||||
|
if save_optimizer_flag:
|
||||||
|
logging.info(f" Saving optimizer state to {save_path}")
|
||||||
|
ed_state.optimizer.save(save_path)
|
||||||
|
|
||||||
|
|
||||||
def setup_local_logger(args):
|
def setup_local_logger(args):
|
||||||
"""
|
"""
|
||||||
configures logger with file and console logging, logs args, and returns the datestamp
|
configures logger with file and console logging, logs args, and returns the datestamp
|
||||||
|
@ -477,95 +581,6 @@ def main(args):
|
||||||
if 'cuda' in original_device.type:
|
if 'cuda' in original_device.type:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def __save_model(save_path, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
|
|
||||||
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
|
|
||||||
|
|
||||||
nonlocal unet
|
|
||||||
nonlocal text_encoder
|
|
||||||
nonlocal unet_ema
|
|
||||||
nonlocal text_encoder_ema
|
|
||||||
|
|
||||||
"""
|
|
||||||
Save the model to disk
|
|
||||||
"""
|
|
||||||
|
|
||||||
def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
|
|
||||||
nonlocal save_ckpt_dir
|
|
||||||
nonlocal save_full_precision
|
|
||||||
nonlocal yaml_name
|
|
||||||
|
|
||||||
if save_ckpt_dir is not None:
|
|
||||||
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
|
||||||
else:
|
|
||||||
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
|
|
||||||
save_ckpt_dir = os.curdir
|
|
||||||
|
|
||||||
half = not save_full_precision
|
|
||||||
|
|
||||||
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
|
||||||
converter(model_path=diffusers_model_path, checkpoint_path=sd_ckpt_full, half=half)
|
|
||||||
|
|
||||||
if yaml_name and yaml_name != "v1-inference.yaml":
|
|
||||||
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(diffusers_model_path))}.yaml"
|
|
||||||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
|
||||||
shutil.copyfile(yaml_name, yaml_save_path)
|
|
||||||
|
|
||||||
|
|
||||||
global global_step
|
|
||||||
|
|
||||||
if global_step is None or global_step == 0:
|
|
||||||
logging.warning(" No model to save, something likely blew up on startup, not saving")
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
if args.ema_decay_rate != None:
|
|
||||||
pipeline_ema = StableDiffusionPipeline(
|
|
||||||
vae=vae,
|
|
||||||
text_encoder=text_encoder_ema,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
unet=unet_ema,
|
|
||||||
scheduler=scheduler,
|
|
||||||
safety_checker=None, # save vram
|
|
||||||
requires_safety_checker=None, # avoid nag
|
|
||||||
feature_extractor=None, # must be none of no safety checker
|
|
||||||
)
|
|
||||||
|
|
||||||
diffusers_model_path = save_path + "_ema"
|
|
||||||
logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}")
|
|
||||||
pipeline_ema.save_pretrained(diffusers_model_path)
|
|
||||||
|
|
||||||
if save_ckpt:
|
|
||||||
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
|
|
||||||
|
|
||||||
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
|
|
||||||
|
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline(
|
|
||||||
vae=vae,
|
|
||||||
text_encoder=text_encoder,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
unet=unet,
|
|
||||||
scheduler=scheduler,
|
|
||||||
safety_checker=None, # save vram
|
|
||||||
requires_safety_checker=None, # avoid nag
|
|
||||||
feature_extractor=None, # must be none of no safety checker
|
|
||||||
)
|
|
||||||
|
|
||||||
diffusers_model_path = save_path
|
|
||||||
logging.info(f" * Saving diffusers model to {diffusers_model_path}")
|
|
||||||
pipeline.save_pretrained(diffusers_model_path)
|
|
||||||
|
|
||||||
if save_ckpt:
|
|
||||||
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
|
||||||
|
|
||||||
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
|
|
||||||
|
|
||||||
|
|
||||||
if save_optimizer_flag:
|
|
||||||
logging.info(f" Saving optimizer state to {save_path}")
|
|
||||||
ed_optimizer.save(save_path)
|
|
||||||
|
|
||||||
|
|
||||||
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_strength_target != None)
|
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_strength_target != None)
|
||||||
ema_model_loaded_from_file = False
|
ema_model_loaded_from_file = False
|
||||||
|
@ -574,6 +589,7 @@ def main(args):
|
||||||
ema_device = torch.device(args.ema_device)
|
ema_device = torch.device(args.ema_device)
|
||||||
|
|
||||||
optimizer_state_path = None
|
optimizer_state_path = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# check for a local file
|
# check for a local file
|
||||||
hf_cache_path = get_hf_ckpt_cache_path(args.resume_ckpt)
|
hf_cache_path = get_hf_ckpt_cache_path(args.resume_ckpt)
|
||||||
|
@ -582,10 +598,6 @@ def main(args):
|
||||||
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
||||||
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
||||||
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
|
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
|
||||||
|
|
||||||
optimizer_state_path = os.path.join(args.resume_ckpt, "optimizer.pt")
|
|
||||||
if not os.path.exists(optimizer_state_path):
|
|
||||||
optimizer_state_path = None
|
|
||||||
else:
|
else:
|
||||||
# try to download from HF using resume_ckpt as a repo id
|
# try to download from HF using resume_ckpt as a repo id
|
||||||
downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
|
downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
|
||||||
|
@ -700,7 +712,9 @@ def main(args):
|
||||||
# Make sure correct types are used for models
|
# Make sure correct types are used for models
|
||||||
unet_ema = unet_ema.to(ema_device, dtype=unet.dtype)
|
unet_ema = unet_ema.to(ema_device, dtype=unet.dtype)
|
||||||
text_encoder_ema = text_encoder_ema.to(ema_device, dtype=text_encoder.dtype)
|
text_encoder_ema = text_encoder_ema.to(ema_device, dtype=text_encoder.dtype)
|
||||||
|
else:
|
||||||
|
unet_ema = None
|
||||||
|
text_encoder_ema = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
#unet = torch.compile(unet)
|
#unet = torch.compile(unet)
|
||||||
|
@ -834,9 +848,9 @@ def main(args):
|
||||||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||||
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
||||||
__save_model(interrupted_checkpoint_path, tokenizer, noise_scheduler, vae,
|
save_model(interrupted_checkpoint_path, global_step=global_step, ed_state=make_current_ed_state(),
|
||||||
ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer,
|
save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
|
||||||
save_ckpt=not args.no_save_ckpt)
|
save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||||
exit(_SIGTERM_EXIT_CODE)
|
exit(_SIGTERM_EXIT_CODE)
|
||||||
else:
|
else:
|
||||||
# non-main threads (i.e. dataloader workers) should exit cleanly
|
# non-main threads (i.e. dataloader workers) should exit cleanly
|
||||||
|
@ -1031,7 +1045,12 @@ def main(args):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def make_save_path(epoch, global_step, prepend=""):
|
def make_save_path(epoch, global_step, prepend=""):
|
||||||
return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
basename = f"{prepend}{args.project_name}"
|
||||||
|
if epoch is not None:
|
||||||
|
basename += f"-ep{epoch:02}"
|
||||||
|
if global_step is not None:
|
||||||
|
basename += f"-gs{global_step:05}"
|
||||||
|
return os.path.join(log_folder, "ckpts", basename)
|
||||||
|
|
||||||
|
|
||||||
# Pre-train validation to establish a starting point on the loss graph
|
# Pre-train validation to establish a starting point on the loss graph
|
||||||
|
@ -1054,25 +1073,43 @@ def main(args):
|
||||||
from plugins.plugins import PluginRunner
|
from plugins.plugins import PluginRunner
|
||||||
plugin_runner = PluginRunner(plugins=plugins)
|
plugin_runner = PluginRunner(plugins=plugins)
|
||||||
|
|
||||||
|
def make_current_ed_state() -> EveryDreamTrainingState:
|
||||||
|
return EveryDreamTrainingState(optimizer=ed_optimizer,
|
||||||
|
train_batch=train_batch,
|
||||||
|
unet=unet,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
scheduler=noise_scheduler,
|
||||||
|
vae=vae,
|
||||||
|
unet_ema=unet_ema,
|
||||||
|
text_encoder_ema=text_encoder_ema)
|
||||||
|
|
||||||
|
epoch = None
|
||||||
try:
|
try:
|
||||||
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
|
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
|
||||||
|
plugin_runner.run_on_training_start(log_folder=log_folder, project_name=args.project_name)
|
||||||
|
|
||||||
for epoch in range(args.max_epochs):
|
for epoch in range(args.max_epochs):
|
||||||
|
|
||||||
if args.load_settings_every_epoch:
|
if args.load_settings_every_epoch:
|
||||||
load_train_json_from_file(args)
|
load_train_json_from_file(args)
|
||||||
|
|
||||||
plugin_runner.run_on_epoch_start(epoch=epoch,
|
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||||
|
|
||||||
|
plugin_runner.run_on_epoch_start(
|
||||||
|
epoch=epoch,
|
||||||
global_step=global_step,
|
global_step=global_step,
|
||||||
|
epoch_length=epoch_len,
|
||||||
project_name=args.project_name,
|
project_name=args.project_name,
|
||||||
log_folder=log_folder,
|
log_folder=log_folder,
|
||||||
data_root=args.data_root)
|
data_root=args.data_root
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
loss_epoch = []
|
loss_epoch = []
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
images_per_sec_log_step = []
|
images_per_sec_log_step = []
|
||||||
|
|
||||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
|
||||||
steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
|
steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
|
||||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
|
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
|
||||||
|
|
||||||
|
@ -1082,13 +1119,16 @@ def main(args):
|
||||||
)
|
)
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
|
|
||||||
plugin_runner.run_on_step_start(epoch=epoch,
|
plugin_runner.run_on_step_start(epoch=epoch,
|
||||||
|
local_step=step,
|
||||||
global_step=global_step,
|
global_step=global_step,
|
||||||
project_name=args.project_name,
|
project_name=args.project_name,
|
||||||
log_folder=log_folder,
|
log_folder=log_folder,
|
||||||
batch=batch)
|
batch=batch,
|
||||||
|
ed_state=make_current_ed_state())
|
||||||
|
|
||||||
model_pred, target, loss = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True)
|
model_pred, target, loss = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True)
|
||||||
|
|
||||||
|
@ -1155,27 +1195,29 @@ def main(args):
|
||||||
|
|
||||||
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
|
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
|
||||||
|
|
||||||
|
needs_save = False
|
||||||
if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes):
|
if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes):
|
||||||
last_epoch_saved_time = time.time()
|
last_epoch_saved_time = time.time()
|
||||||
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
||||||
save_path = make_save_path(epoch, global_step)
|
needs_save = True
|
||||||
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
|
|
||||||
args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
|
|
||||||
save_ckpt=not args.no_save_ckpt)
|
|
||||||
|
|
||||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
||||||
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
||||||
|
needs_save = True
|
||||||
|
if needs_save:
|
||||||
save_path = make_save_path(epoch, global_step)
|
save_path = make_save_path(epoch, global_step)
|
||||||
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
|
save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
|
||||||
args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
|
save_ckpt_dir=None, yaml_name=None,
|
||||||
save_ckpt=not args.no_save_ckpt)
|
save_full_precision=args.save_full_precision,
|
||||||
|
save_optimizer_flag=args.save_optimizer, save_ckpt=False)
|
||||||
|
|
||||||
plugin_runner.run_on_step_end(epoch=epoch,
|
plugin_runner.run_on_step_end(epoch=epoch,
|
||||||
global_step=global_step,
|
global_step=global_step,
|
||||||
|
local_step=step,
|
||||||
project_name=args.project_name,
|
project_name=args.project_name,
|
||||||
log_folder=log_folder,
|
log_folder=log_folder,
|
||||||
data_root=args.data_root,
|
data_root=args.data_root,
|
||||||
batch=batch)
|
batch=batch,
|
||||||
|
ed_state=make_current_ed_state())
|
||||||
|
|
||||||
del batch
|
del batch
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
@ -1192,6 +1234,7 @@ def main(args):
|
||||||
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
|
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
|
||||||
write_batch_schedule(args, log_folder, train_batch, epoch + 1)
|
write_batch_schedule(args, log_folder, train_batch, epoch + 1)
|
||||||
|
|
||||||
|
if len(loss_epoch) > 0:
|
||||||
loss_epoch = sum(loss_epoch) / len(loss_epoch)
|
loss_epoch = sum(loss_epoch) / len(loss_epoch)
|
||||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
|
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
|
||||||
|
|
||||||
|
@ -1206,9 +1249,13 @@ def main(args):
|
||||||
|
|
||||||
# end of training
|
# end of training
|
||||||
epoch = args.max_epochs
|
epoch = args.max_epochs
|
||||||
|
|
||||||
|
plugin_runner.run_on_training_end()
|
||||||
|
|
||||||
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
|
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
|
||||||
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
|
save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
|
||||||
yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
|
||||||
|
save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||||
|
|
||||||
total_elapsed_time = time.time() - training_start_time
|
total_elapsed_time = time.time() - training_start_time
|
||||||
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
||||||
|
@ -1218,8 +1265,9 @@ def main(args):
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
||||||
save_path = make_save_path(epoch, global_step, prepend="errored-")
|
save_path = make_save_path(epoch, global_step, prepend="errored-")
|
||||||
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
|
save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
|
||||||
yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
|
||||||
|
save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||||
logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
|
logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
|
||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue