diff --git a/plugins/interruptible.py b/plugins/interruptible.py
new file mode 100644
index 0000000..13dbf6b
--- /dev/null
+++ b/plugins/interruptible.py
@@ -0,0 +1,148 @@
+import math
+import os
+import shutil
+from plugins.plugins import BasePlugin
+from train import save_model
+
+# How often to save: integers >= 1 save at the end of every nth epoch;
+# floats < 1 subdivide the epoch evenly (e.g. 0.33 = 3 subdivisions per epoch).
+EVERY_N_EPOCHS = 0.3
+
+class InterruptiblePlugin(BasePlugin):
+
+    def __init__(self):
+        print("Interruptible plugin instantiated")
+        self.previous_save_path = None
+        self.every_n_epochs = EVERY_N_EPOCHS
+
+    def on_epoch_start(self, **kwargs):
+        epoch = kwargs['epoch']
+        epoch_length = kwargs['epoch_length']
+        self.steps_to_save_this_epoch = self._get_save_step_indices(epoch, epoch_length)
+
+    def on_step_end(self, **kwargs):
+        local_step = kwargs['local_step']
+        if local_step in self.steps_to_save_this_epoch:
+            global_step = kwargs['global_step']
+            epoch = kwargs['epoch']
+            project_name = kwargs['project_name']
+            log_folder = kwargs['log_folder']
+            ckpt_name = f"rolling-{project_name}-ep{epoch:02}-gs{global_step:05}"
+            save_path = os.path.join(log_folder, "ckpts", ckpt_name)
+            print(f"{type(self).__name__} saving model to {save_path}")
+            save_model(save_path, global_step=global_step, ed_state=kwargs['ed_state'],
+                       save_ckpt_dir=None, yaml_name=None, save_ckpt=False,
+                       save_full_precision=True, save_optimizer_flag=True)
+            self._remove_previous()
+            self.previous_save_path = save_path
+
+    def on_training_end(self, **kwargs):
+        self._remove_previous()
+
+    def _remove_previous(self):
+        if self.previous_save_path is not None:
+            shutil.rmtree(self.previous_save_path, ignore_errors=True)
+        self.previous_save_path = None
+
+    def _get_save_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
+        if self.every_n_epochs >= 1:
+            if ((epoch+1) % self.every_n_epochs) == 0:
+                # save at the last step of the epoch only
+                return [epoch_length_steps-1]
+            else:
+                return []
+        else:
+            # subdivide the epoch evenly, by rounding 1/self.every_n_epochs to the nearest clean division of steps
+            num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
+            # saves happen at the end of a step:
+            # if an epoch has e.g. 100 steps and num_divisions is 2, saves occur after steps 49 and 99
+            save_every_n_steps = epoch_length_steps / num_divisions
+            return [math.ceil((i+1)*save_every_n_steps) - 1 for i in range(num_divisions)]
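+
+# Illustrative check of the subdivision math above (comment only, never executed):
+# with EVERY_N_EPOCHS = 0.3 and a 10-step epoch, round(1/0.3) == 3 divisions,
+# so saves land after steps 3, 6 and 9:
+#   >>> [math.ceil((i+1) * (10/3)) - 1 for i in range(3)]
+#   [3, 6, 9]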
+
+"""
+class InterruptiblePlugin_(BasePlugin):
+    def __init__(self, log_folder, args):
+        self.log_folder = log_folder
+        self.project_name = args.project_name
+        self.max_epochs = args.max_epochs
+
+        self.every_n_epochs = 1
+
+    @classmethod
+    def make_resume_path(cls, resume_ckpt_folder):
+        return os.path.join(resume_ckpt_folder, 'resumable_data.pt')
+
+    def load_resume_state(self, resume_ckpt_path: str, ed_state: EveryDreamTrainingState):
+        resume_path = self.make_resume_path(resume_ckpt_path)
+        try:
+            with open(resume_path, 'rb') as f:
+                resumable_data = torch.load(f)
+                ed_state.optimizer.load_state_dict(resumable_data['ed_optimizer'])
+                ed_state.train_batch.load_state_dict(resumable_data['ed_batch'])
+        except Exception as e:
+            print(f"InterruptiblePlugin unable to load resume state from {resume_path}: {e}")
+            return
+
+    def on_epoch_start(self, ed_state: EveryDreamTrainingState, **kwargs):
+        epoch = kwargs['epoch']
+        epoch_length = kwargs['epoch_length']
+        if epoch == 0:
+            resume_ckpt_path = kwargs['resume_ckpt_path']
+            self.load_resume_state(resume_ckpt_path, ed_state)
+        self.steps_to_save_this_epoch = self._get_save_step_indices(epoch, epoch_length)
+
+    def _get_save_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
+        if self.every_n_epochs >= 1:
+            if ((epoch+1) % self.every_n_epochs) == 0:
+                # last step only
+                return [epoch_length_steps-1]
+            else:
+                return []
+        else:
+            # subdivide the epoch evenly, by rounding 1/self.every_n_epochs to the nearest clean division of steps
+            num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
+            # saves happen at the end of a step:
+            # if an epoch has e.g. 100 steps and num_divisions is 2, saves occur after steps 49 and 99
+            save_every_n_steps = epoch_length_steps / num_divisions
+            return [math.ceil((i+1)*save_every_n_steps) - 1 for i in range(num_divisions)]
+
+    def on_step_end(self, epoch: int, local_step: int, global_step: int, ed_state: EveryDreamTrainingState):
+        if local_step in self.steps_to_save_this_epoch:
+            self._save_and_remove_prior(epoch, global_step, ed_state)
+
+    def _save_and_remove_prior(self, epoch: int, global_step: int, ed_state: EveryDreamTrainingState):
+        rolling_save_path = self.make_save_path(epoch, global_step, prepend="rolling-")
+        ed_optimizer: EveryDreamOptimizer = ed_state.optimizer
+        save_model(rolling_save_path,
+                   ed_state=ed_state, save_ckpt_dir=None, yaml_name=None, save_ckpt=False, save_optimizer_flag=True)
+
+        train_batch: EveryDreamBatch = ed_state.train_batch
+        resumable_data = {
+            'grad_scaler': ed_optimizer.scaler.state_dict(),
+            'epoch': epoch,
+            'global_step': global_step,
+            'train_batch': train_batch.state_dict()
+        }
+        if ed_optimizer.lr_scheduler_te is not None:
+            resumable_data['lr_scheduler_te'] = ed_optimizer.lr_scheduler_te.state_dict()
+        if ed_optimizer.lr_scheduler_unet is not None:
+            resumable_data['lr_scheduler_unet'] = ed_optimizer.lr_scheduler_unet.state_dict()
+
+        torch.save(resumable_data, os.path.join(rolling_save_path, 'resumable_data.pt'))
+
+        self.prev_epoch = epoch
+        self.prev_global_step = global_step
+        if epoch > 0:
+            prev_rolling_save_path = self.make_save_path(epoch, self.prev_global_step, prepend="rolling-")
+            shutil.rmtree(prev_rolling_save_path, ignore_errors=True)
+
+    def make_save_path(self, epoch, global_step, prepend: str="") -> str:
+        basename = f"{prepend}{self.project_name}-ep{epoch:02}"
+        if global_step is not None:
+            basename += f"-gs{global_step:05}"
+        return os.path.join(self.log_folder, "ckpts", basename)
+"""
\ No newline at end of file
diff --git a/plugins/plugins.py b/plugins/plugins.py
index 310905c..dc65ff2 100644
--- a/plugins/plugins.py
+++ b/plugins/plugins.py
@@ -44,7 +44,7 @@ class Timer:
     def __exit__(self, type, value, traceback):
         elapsed_time = time.time() - self.start
         if elapsed_time > self.warn_seconds:
-            logging.warning(f'Execution of {self.label} took {elapsed_time} seconds which is longer than the limit of {self.limit} seconds')
+            logging.warning(f'Execution of {self.label} took {elapsed_time} seconds which is longer than the limit of {self.warn_seconds} seconds')


 class PluginRunner:
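A note on the `Timer` hunk above: `warn_seconds` is the threshold field that `__exit__` actually compares against, so the warning now reports the same value it checks. A minimal usage sketch follows; it assumes the constructor takes the `label` and `warn_seconds` fields referenced in `__exit__`, and how `PluginRunner` wires this up is outside this diff:

```python
from plugins.plugins import Timer

# warn if the wrapped plugin callback runs longer than 5 seconds
with Timer(label="on_step_end", warn_seconds=5):
    plugin.on_step_end(epoch=epoch, local_step=step, global_step=global_step)
```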
diff --git a/train.py b/train.py
index 2fdd944..2e4fbed 100644
--- a/train.py
+++ b/train.py
@@ -27,6 +27,7 @@ import gc
 import random
 import traceback
 import shutil
+from typing import Optional

 import torch.nn.functional as F
 from torch.cuda.amp import autocast
@@ -102,6 +103,109 @@ def convert_to_hf(ckpt_path):
     is_sd1attn, yaml = get_attn_yaml(ckpt_path)
     return ckpt_path, is_sd1attn, yaml

+class EveryDreamTrainingState:
+    def __init__(self,
+                 optimizer: EveryDreamOptimizer,
+                 train_batch: EveryDreamBatch,
+                 unet: UNet2DConditionModel,
+                 text_encoder: CLIPTextModel,
+                 tokenizer: CLIPTokenizer,
+                 scheduler,
+                 vae: AutoencoderKL,
+                 unet_ema: Optional[UNet2DConditionModel],
+                 text_encoder_ema: Optional[CLIPTextModel]
+                 ):
+        self.optimizer = optimizer
+        self.train_batch = train_batch
+        self.unet = unet
+        self.text_encoder = text_encoder
+        self.tokenizer = tokenizer
+        self.scheduler = scheduler
+        self.vae = vae
+        self.unet_ema = unet_ema
+        self.text_encoder_ema = text_encoder_ema
+
+
+@torch.no_grad()
+def save_model(save_path, ed_state: EveryDreamTrainingState, global_step: int, save_ckpt_dir, yaml_name,
+               save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
+    """
+    Save the model to disk
+    """
+
+    def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
+        nonlocal save_ckpt_dir
+        nonlocal save_full_precision
+        nonlocal yaml_name
+
+        if save_ckpt_dir is not None:
+            sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
+        else:
+            sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
+            save_ckpt_dir = os.curdir
+
+        half = not save_full_precision
+
+        logging.info(f" * Saving SD model to {sd_ckpt_full}")
+        converter(model_path=diffusers_model_path, checkpoint_path=sd_ckpt_full, half=half)
+
+        if yaml_name and yaml_name != "v1-inference.yaml":
+            yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(diffusers_model_path))}.yaml"
+            logging.info(f" * Saving yaml to {yaml_save_path}")
+            shutil.copyfile(yaml_name, yaml_save_path)
+
+    if global_step is None or global_step == 0:
+        logging.warning("  No model to save, something likely blew up on startup, not saving")
+        return
+
+    if args.ema_decay_rate is not None:
+        pipeline_ema = StableDiffusionPipeline(
+            vae=ed_state.vae,
+            text_encoder=ed_state.text_encoder_ema,
+            tokenizer=ed_state.tokenizer,
+            unet=ed_state.unet_ema,
+            scheduler=ed_state.scheduler,
+            safety_checker=None,  # save vram
+            requires_safety_checker=None,  # avoid nag
+            feature_extractor=None,  # must be None if there is no safety checker
+        )
+
+        diffusers_model_path = save_path + "_ema"
+        logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}")
+        pipeline_ema.save_pretrained(diffusers_model_path)
+
+        if save_ckpt:
+            sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
+            save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
+
+    pipeline = StableDiffusionPipeline(
+        vae=ed_state.vae,
+        text_encoder=ed_state.text_encoder,
+        tokenizer=ed_state.tokenizer,
+        unet=ed_state.unet,
+        scheduler=ed_state.scheduler,
+        safety_checker=None,  # save vram
+        requires_safety_checker=None,  # avoid nag
+        feature_extractor=None,  # must be None if there is no safety checker
+    )
+
+    diffusers_model_path = save_path
+    logging.info(f" * Saving diffusers model to {diffusers_model_path}")
+    pipeline.save_pretrained(diffusers_model_path)
+
+    if save_ckpt:
+        sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
+        save_ckpt_file(diffusers_model_path, sd_ckpt_path)
+
+    if save_optimizer_flag:
+        logging.info(f" Saving optimizer state to {save_path}")
+        ed_state.optimizer.save(save_path)
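+
+# Illustrative call (hypothetical values; mirrors how the training loop below invokes this):
+#   save_model(make_save_path(epoch, global_step), ed_state=make_current_ed_state(),
+#              global_step=global_step, save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml,
+#              save_full_precision=args.save_full_precision,
+#              save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)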
+
+
 def setup_local_logger(args):
     """
     configures logger with file and console logging, logs args, and returns the datestamp
@@ -478,95 +582,6 @@ def main(args):
         if 'cuda' in original_device.type:
             torch.cuda.empty_cache()

-    @torch.no_grad()
-    def __save_model(save_path, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
-                     save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
-
-        nonlocal unet
-        nonlocal text_encoder
-        nonlocal unet_ema
-        nonlocal text_encoder_ema
-
-        """
-        Save the model to disk
-        """
-
-        def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
-            nonlocal save_ckpt_dir
-            nonlocal save_full_precision
-            nonlocal yaml_name
-
-            if save_ckpt_dir is not None:
-                sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
-            else:
-                sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
-                save_ckpt_dir = os.curdir
-
-            half = not save_full_precision
-
-            logging.info(f" * Saving SD model to {sd_ckpt_full}")
-            converter(model_path=diffusers_model_path, checkpoint_path=sd_ckpt_full, half=half)
-
-            if yaml_name and yaml_name != "v1-inference.yaml":
-                yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(diffusers_model_path))}.yaml"
-                logging.info(f" * Saving yaml to {yaml_save_path}")
-                shutil.copyfile(yaml_name, yaml_save_path)
-
-
-        global global_step
-
-        if global_step is None or global_step == 0:
-            logging.warning("  No model to save, something likely blew up on startup, not saving")
-            return
-
-
-        if args.ema_decay_rate != None:
-            pipeline_ema = StableDiffusionPipeline(
-                vae=vae,
-                text_encoder=text_encoder_ema,
-                tokenizer=tokenizer,
-                unet=unet_ema,
-                scheduler=scheduler,
-                safety_checker=None, # save vram
-                requires_safety_checker=None, # avoid nag
-                feature_extractor=None, # must be none of no safety checker
-            )
-
-            diffusers_model_path = save_path + "_ema"
-            logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}")
-            pipeline_ema.save_pretrained(diffusers_model_path)
-
-            if save_ckpt:
-                sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
-
-                save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
-
-
-        pipeline = StableDiffusionPipeline(
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=None, # save vram
-            requires_safety_checker=None, # avoid nag
-            feature_extractor=None, # must be none of no safety checker
-        )
-
-        diffusers_model_path = save_path
-        logging.info(f" * Saving diffusers model to {diffusers_model_path}")
-        pipeline.save_pretrained(diffusers_model_path)
-
-        if save_ckpt:
-            sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
-
-            save_ckpt_file(diffusers_model_path, sd_ckpt_path)
-
-
-        if save_optimizer_flag:
-            logging.info(f" Saving optimizer state to {save_path}")
-            ed_optimizer.save(save_path)
-
     use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_strength_target != None)
     ema_model_loaded_from_file = False

@@ -575,6 +590,7 @@ def main(args):
         ema_device = torch.device(args.ema_device)

     optimizer_state_path = None
+
     try:
         # check for a local file
        hf_cache_path = get_hf_ckpt_cache_path(args.resume_ckpt)
@@ -583,10 +599,6 @@
             text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
             vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
             unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
-
-            optimizer_state_path = os.path.join(args.resume_ckpt, "optimizer.pt")
-            if not os.path.exists(optimizer_state_path):
-                optimizer_state_path = None
         else:
             # try to download from HF using resume_ckpt as a repo id
             downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
@@ -701,7 +713,9 @@
         # Make sure correct types are used for models
         unet_ema = unet_ema.to(ema_device, dtype=unet.dtype)
         text_encoder_ema = text_encoder_ema.to(ema_device, dtype=text_encoder.dtype)
-
+    else:
+        unet_ema = None
+        text_encoder_ema = None

     try:
         #unet = torch.compile(unet)
@@ -835,9 +849,9 @@
             logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
             logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
             time.sleep(2) # give opportunity to ctrl-C again to cancel save
-            __save_model(interrupted_checkpoint_path, tokenizer, noise_scheduler, vae,
-                         ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer,
-                         save_ckpt=not args.no_save_ckpt)
+            save_model(interrupted_checkpoint_path, global_step=global_step, ed_state=make_current_ed_state(),
+                       save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
+                       save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
             exit(_SIGTERM_EXIT_CODE)
         else:
             # non-main threads (i.e. dataloader workers) should exit cleanly
@@ -1032,7 +1046,12 @@
             torch.cuda.empty_cache()

     def make_save_path(epoch, global_step, prepend=""):
-        return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
+        basename = f"{prepend}{args.project_name}"
+        if epoch is not None:
+            basename += f"-ep{epoch:02}"
+        if global_step is not None:
+            basename += f"-gs{global_step:05}"
+        return os.path.join(log_folder, "ckpts", basename)
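+        # e.g. with a hypothetical project_name of "myproj":
+        #   make_save_path(2, 1234, prepend="rolling-")
+        #   -> "<log_folder>/ckpts/rolling-myproj-ep02-gs01234"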


@@ -1057,26 +1076,42 @@
     from plugins.plugins import PluginRunner
     plugin_runner = PluginRunner(plugins=plugins)

+    def make_current_ed_state() -> EveryDreamTrainingState:
+        return EveryDreamTrainingState(optimizer=ed_optimizer,
+                                       train_batch=train_batch,
+                                       unet=unet,
+                                       text_encoder=text_encoder,
+                                       tokenizer=tokenizer,
+                                       scheduler=noise_scheduler,
+                                       vae=vae,
+                                       unet_ema=unet_ema,
+                                       text_encoder_ema=text_encoder_ema)
+
+    epoch = None
     try:
         write_batch_schedule(args, log_folder, train_batch, epoch = 0)

+        plugin_runner.run_on_training_start(log_folder=log_folder, project_name=args.project_name)
+
         for epoch in range(args.max_epochs):
             if args.load_settings_every_epoch:
                 load_train_json_from_file(args)
+
+            epoch_len = math.ceil(len(train_batch) / args.batch_size)

-            plugin_runner.run_on_epoch_start(epoch=epoch,
-                                             global_step=global_step,
-                                             project_name=args.project_name,
-                                             log_folder=log_folder,
-                                             data_root=args.data_root)
+            plugin_runner.run_on_epoch_start(
+                epoch=epoch,
+                global_step=global_step,
+                epoch_length=epoch_len,
+                project_name=args.project_name,
+                log_folder=log_folder,
+                data_root=args.data_root
+            )

             loss_epoch = []
             epoch_start_time = time.time()
             images_per_sec_log_step = []

-            epoch_len = math.ceil(len(train_batch) / args.batch_size)
             steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
             steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
@@ -1086,12 +1121,15 @@
             )

             for step, batch in enumerate(train_dataloader):
+                step_start_time = time.time()

                 plugin_runner.run_on_step_start(epoch=epoch,
+                                                local_step=step,
                                                 global_step=global_step,
                                                 project_name=args.project_name,
                                                 log_folder=log_folder,
-                                                batch=batch)
+                                                batch=batch,
+                                                ed_state=make_current_ed_state())

                 model_pred, target, loss = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True)
@@ -1158,27 +1196,29 @@
                 min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60

+                needs_save = False
                 if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes):
                     last_epoch_saved_time = time.time()
                     logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
-                    save_path = make_save_path(epoch, global_step)
-                    __save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
-                                 args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
-                                 save_ckpt=not args.no_save_ckpt)
-
+                    needs_save = True
                 if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
                     logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
+                    needs_save = True
+                if needs_save:
                     save_path = make_save_path(epoch, global_step)
-                    __save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
-                                 args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
-                                 save_ckpt=not args.no_save_ckpt)
+                    save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
+                               save_ckpt_dir=None, yaml_name=None,
+                               save_full_precision=args.save_full_precision,
+                               save_optimizer_flag=args.save_optimizer, save_ckpt=False)

                 plugin_runner.run_on_step_end(epoch=epoch,
                                               global_step=global_step,
+                                              local_step=step,
                                               project_name=args.project_name,
                                               log_folder=log_folder,
                                               data_root=args.data_root,
-                                              batch=batch)
+                                              batch=batch,
+                                              ed_state=make_current_ed_state())

                 del batch
                 global_step += 1
@@ -1195,8 +1235,9 @@
             train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
             write_batch_schedule(args, log_folder, train_batch, epoch + 1)

-            loss_epoch = sum(loss_epoch) / len(loss_epoch)
-            log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
+            if len(loss_epoch) > 0:
+                loss_epoch = sum(loss_epoch) / len(loss_epoch)
+                log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)

             plugin_runner.run_on_epoch_end(epoch=epoch,
                                            global_step=global_step,
@@ -1209,9 +1250,13 @@
         # end of training
         epoch = args.max_epochs
+
+        plugin_runner.run_on_training_end()
+
         save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
-        __save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
-                     yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+        save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
+                   save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
+                   save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)

         total_elapsed_time = time.time() - training_start_time
         logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@@ -1221,8 +1266,9 @@
     except Exception as ex:
         logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
         save_path = make_save_path(epoch, global_step, prepend="errored-")
-        __save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
-                     yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
+        save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
+                   save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
+                   save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
         logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting.  Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
         raise ex
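For reference, a plugin consuming the hook surface extended above might look like the sketch below. `EchoPlugin` is hypothetical, and it only reads kwargs that the loop actually passes (`epoch`, `epoch_length`, `local_step`, `global_step`, `ed_state`):

```python
from plugins.plugins import BasePlugin

class EchoPlugin(BasePlugin):  # hypothetical example, not part of this diff
    def on_epoch_start(self, **kwargs):
        # epoch_length is the kwarg newly passed by run_on_epoch_start
        print(f"epoch {kwargs['epoch']}: {kwargs['epoch_length']} steps expected")

    def on_step_end(self, **kwargs):
        # local_step and ed_state are the kwargs newly passed by run_on_step_end
        if kwargs['local_step'] == 0:
            unet = kwargs['ed_state'].unet
            print(f"global step {kwargs['global_step']}: unet dtype is {unet.dtype}")
```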