cleanup
parent 74c21602fc
commit 3579bec540
@@ -54,95 +54,3 @@ class InterruptiblePlugin(BasePlugin):
             # if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
             validate_every_n_steps = epoch_length_steps / num_divisions
             return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
-
-"""
-class InterruptiblePlugin_(BasePlugin):
-    def __init__(self, log_folder, args):
-        self.log_folder = log_folder
-        self.project_name = args.project_name
-        self.max_epochs = args.max_epochs
-
-        self.every_n_epochs = 1
-
-
-    @classmethod
-    def make_resume_path(cls, resume_ckpt_folder):
-        return os.path.join(resume_ckpt_folder, 'resumable_data.pt')
-
-    def load_resume_state(self, resume_ckpt_path: str, ed_state: EveryDreamTrainingState):
-        resume_path = self.make_resume_path(resume_ckpt_path)
-        try:
-            with open(resume_path, 'rb') as f:
-                resumable_data = torch.load(f)
-            ed_state.optimizer.load_state_dict(resumable_data['ed_optimizer'])
-            ed_state.train_batch.load_state_dict(resumable_data['ed_batch'])
-        except Exception as e:
-            print(f"InterruptiblePlugin unable to load resume state from {resume_path}: {e}")
-            return
-
-
-    def on_epoch_start(self, ed_state: EveryDreamTrainingState, **kwargs):
-        epoch = kwargs['epoch']
-        epoch_length = kwargs['epoch_length']
-        if epoch == 0:
-            resume_ckpt_path = kwargs['resume_ckpt_path']
-            self.load_resume_state(resume_ckpt_path, ed_state)
-        self.steps_to_save_this_epoch = self._get_save_step_indices(epoch, epoch_length)
-
-    def _get_save_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
-        if self.every_n_epochs >= 1:
-            if ((epoch+1) % self.every_n_epochs) == 0:
-                # last step only
-                return [epoch_length_steps-1]
-            else:
-                return []
-        else:
-            # subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
-            num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
-            # validation happens after training:
-            # if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
-            validate_every_n_steps = epoch_length_steps / num_divisions
-            return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
-
-    def on_step_end(self, epoch: int, local_step: int, global_step: int, ed_state: EveryDreamTrainingState):
-        if local_step in self.steps_to_save_this_epoch:
-            self.save_and_remove_prior(epoch, global_step, ed_state)
-
-    def _save_and_remove_prior(self, epoch: int, global_step: int, ed_state: EveryDreamTrainingState):
-        rolling_save_path = self.make_save_path(epoch, global_step, prepend="rolling-")
-        ed_optimizer: EveryDreamOptimizer = ed_state.optimizer
-        save_model(rolling_save_path,
-                   ed_state=ed_state, save_ckpt_dir=None, yaml_name=None, save_ckpt=False, save_optimizer_flag=True)
-
-                   kwargs['unet'], kwargs['text_encoder'], kwargs['tokenizer'],
-                   kwargs['noise_scheduler'], kwargs['vae'], ed_optimizer,
-                   save_ckpt_dir=None, yaml_name=None, save_optimizer_flag=True, save_ckpt=False)
-
-        train_batch: EveryDreamBatch = kwargs['train_batch']
-        resumable_data = {
-            'grad_scaler': ed_optimizer.scaler.state_dict(),
-            'epoch': epoch,
-            'global_step': global_step,
-            'train_batch': train_batch.state_dict()
-        }
-        if ed_optimizer.lr_scheduler_te is not None:
-            resumable_data['lr_scheduler_te'] = ed_optimizer.lr_scheduler_te.state_dict()
-        if ed_optimizer.lr_scheduler_unet is not None:
-            resumable_data['lr_scheduler_unet'] = ed_optimizer.lr_scheduler_unet.state_dict()
-
-        torch.save(resumable_data, os.path.join(rolling_save_path, 'resumable_data.pt'))
-
-        self.prev_epoch = epoch
-        self.prev_global_step = global_step
-        if epoch > 0:
-            prev_rolling_save_path = self.make_save_path(epoch, self.prev_global_step, prepend="rolling-")
-            shutil.rmtree(prev_rolling_save_path, ignore_errors=True)
-
-        pass
-
-    def make_save_path(self, epoch, global_step, prepend: str="") -> str:
-        basename = f"{prepend}{self.project_name}-ep{epoch:02}"
-        if global_step is not None:
-            basename += f"-gs{global_step:05}"
-        return os.path.join(self.log_folder, "ckpts", basename)
-"""