diff --git a/plugins/interruptible.py b/plugins/interruptible.py
index e011523..22c8507 100644
--- a/plugins/interruptible.py
+++ b/plugins/interruptible.py
@@ -54,95 +54,3 @@ class InterruptiblePlugin(BasePlugin):
             # if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
             validate_every_n_steps = epoch_length_steps / num_divisions
             return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
-
-"""
-class InterruptiblePlugin_(BasePlugin):
-    def __init__(self, log_folder, args):
-        self.log_folder = log_folder
-        self.project_name = args.project_name
-        self.max_epochs = args.max_epochs
-
-        self.every_n_epochs = 1
-
-
-    @classmethod
-    def make_resume_path(cls, resume_ckpt_folder):
-        return os.path.join(resume_ckpt_folder, 'resumable_data.pt')
-
-    def load_resume_state(self, resume_ckpt_path: str, ed_state: EveryDreamTrainingState):
-        resume_path = self.make_resume_path(resume_ckpt_path)
-        try:
-            with open(resume_path, 'rb') as f:
-                resumable_data = torch.load(f)
-                ed_state.optimizer.load_state_dict(resumable_data['ed_optimizer'])
-                ed_state.train_batch.load_state_dict(resumable_data['ed_batch'])
-        except Exception as e:
-            print(f"InterruptiblePlugin unable to load resume state from {resume_path}: {e}")
-            return
-
-
-    def on_epoch_start(self, ed_state: EveryDreamTrainingState, **kwargs):
-        epoch = kwargs['epoch']
-        epoch_length = kwargs['epoch_length']
-        if epoch == 0:
-            resume_ckpt_path = kwargs['resume_ckpt_path']
-            self.load_resume_state(resume_ckpt_path, ed_state)
-        self.steps_to_save_this_epoch = self._get_save_step_indices(epoch, epoch_length)
-
-    def _get_save_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
-        if self.every_n_epochs >= 1:
-            if ((epoch+1) % self.every_n_epochs) == 0:
-                # last step only
-                return [epoch_length_steps-1]
-            else:
-                return []
-        else:
-            # subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
-            num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
-            # validation happens after training:
-            # if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
-            validate_every_n_steps = epoch_length_steps / num_divisions
-            return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
-
-    def on_step_end(self, epoch: int, local_step: int, global_step: int, ed_state: EveryDreamTrainingState):
-        if local_step in self.steps_to_save_this_epoch:
-            self.save_and_remove_prior(epoch, global_step, ed_state)
-
-    def _save_and_remove_prior(self, epoch: int, global_step: int, ed_state: EveryDreamTrainingState):
-        rolling_save_path = self.make_save_path(epoch, global_step, prepend="rolling-")
-        ed_optimizer: EveryDreamOptimizer = ed_state.optimizer
-        save_model(rolling_save_path,
-                   ed_state=ed_state, save_ckpt_dir=None, yaml_name=None, save_ckpt=False, save_optimizer_flag=True)
-
-                   kwargs['unet'], kwargs['text_encoder'], kwargs['tokenizer'],
-                   kwargs['noise_scheduler'], kwargs['vae'], ed_optimizer,
-                   save_ckpt_dir=None, yaml_name=None, save_optimizer_flag=True, save_ckpt=False)
-
-        train_batch: EveryDreamBatch = kwargs['train_batch']
-        resumable_data = {
-            'grad_scaler': ed_optimizer.scaler.state_dict(),
-            'epoch': epoch,
-            'global_step': global_step,
-            'train_batch': train_batch.state_dict()
-        }
-        if ed_optimizer.lr_scheduler_te is not None:
-            resumable_data['lr_scheduler_te'] = ed_optimizer.lr_scheduler_te.state_dict()
-        if ed_optimizer.lr_scheduler_unet is not None:
-            resumable_data['lr_scheduler_unet'] = ed_optimizer.lr_scheduler_unet.state_dict()
-
-        torch.save(resumable_data, os.path.join(rolling_save_path, 'resumable_data.pt'))
-
-        self.prev_epoch = epoch
-        self.prev_global_step = global_step
-        if epoch > 0:
-            prev_rolling_save_path = self.make_save_path(epoch, self.prev_global_step, prepend="rolling-")
-            shutil.rmtree(prev_rolling_save_path, ignore_errors=True)
-
-        pass
-
-    def make_save_path(self, epoch, global_step, prepend: str="") -> str:
-        basename = f"{prepend}{self.project_name}-ep{epoch:02}"
-        if global_step is not None:
-            basename += f"-gs{global_step:05}"
-        return os.path.join(self.log_folder, "ckpts", basename)
-"""
\ No newline at end of file
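
Note (not part of the diff): a minimal standalone sketch of the save-step index computation that the retained context lines above come from, assuming the same rounding behavior; the helper name and example values below are hypothetical and only illustrate the math.

import math

def save_step_indices(epoch_length_steps: int, every_n_epochs: float) -> list[int]:
    # Hypothetical helper mirroring the retained logic: round 1/every_n_epochs to the
    # nearest clean division of the epoch, clamped to [1, epoch_length_steps].
    num_divisions = max(1, min(epoch_length_steps, round(1 / every_n_epochs)))
    validate_every_n_steps = epoch_length_steps / num_divisions
    # Saves happen after a step finishes training, hence the -1.
    return [math.ceil((i + 1) * validate_every_n_steps) - 1 for i in range(num_divisions)]

print(save_step_indices(100, 0.5))  # -> [49, 99], matching the comment in the retained code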