allow sub-epoch validation when every_n_epochs <1
This commit is contained in:
parent
f0111a6e2b
commit
103ab20696
|
@ -97,34 +97,44 @@ class EveryDreamValidator:
|
|||
remaining_train_items, tokenizer)
|
||||
return remaining_train_items
|
||||
|
||||
def do_validation_if_appropriate(self, epoch: int, global_step: int,
|
||||
get_model_prediction_and_target_callable: Callable[
|
||||
def get_validation_step_indices(self, epoch_length_steps: int) -> list[int]:
|
||||
if self.every_n_epochs >= 1:
|
||||
# last step only
|
||||
return [epoch_length_steps-1]
|
||||
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
|
||||
num_divisions = min(epoch_length_steps, math.ceil(1/self.every_n_epochs))
|
||||
# validation happens after training:
|
||||
# if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
|
||||
validate_every_n_steps = epoch_length_steps / num_divisions
|
||||
return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
|
||||
|
||||
def do_validation(self, global_step: int,
|
||||
get_model_prediction_and_target_callable: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||
if (epoch % self.every_n_epochs) == 0:
|
||||
if self.train_overlapping_dataloader is not None:
|
||||
mean_loss = self._calculate_validation_loss('stabilize-train',
|
||||
self.train_overlapping_dataloader,
|
||||
get_model_prediction_and_target_callable)
|
||||
if self.train_overlapping_dataloader_loss_offset is None:
|
||||
self.train_overlapping_dataloader_loss_offset = -mean_loss
|
||||
self.log_writer.add_scalar(tag=f"loss/stabilize-train",
|
||||
scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss,
|
||||
global_step=global_step)
|
||||
if self.val_dataloader is not None:
|
||||
mean_loss = self._calculate_validation_loss('val',
|
||||
self.val_dataloader,
|
||||
get_model_prediction_and_target_callable)
|
||||
if self.val_loss_offset is None:
|
||||
self.val_loss_offset = -mean_loss
|
||||
self.log_writer.add_scalar(tag=f"loss/val",
|
||||
scalar_value=self.val_loss_offset + mean_loss,
|
||||
global_step=global_step)
|
||||
self.loss_val_history.append(mean_loss)
|
||||
if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
|
||||
dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
|
||||
if np.average(dy) > 0:
|
||||
logging.warning(f"Validation loss shows diverging")
|
||||
# todo: signal stop?
|
||||
if self.train_overlapping_dataloader is not None:
|
||||
mean_loss = self._calculate_validation_loss('stabilize-train',
|
||||
self.train_overlapping_dataloader,
|
||||
get_model_prediction_and_target_callable)
|
||||
if self.train_overlapping_dataloader_loss_offset is None:
|
||||
self.train_overlapping_dataloader_loss_offset = -mean_loss
|
||||
self.log_writer.add_scalar(tag=f"loss/stabilize-train",
|
||||
scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss,
|
||||
global_step=global_step)
|
||||
if self.val_dataloader is not None:
|
||||
mean_loss = self._calculate_validation_loss('val',
|
||||
self.val_dataloader,
|
||||
get_model_prediction_and_target_callable)
|
||||
if self.val_loss_offset is None:
|
||||
self.val_loss_offset = -mean_loss
|
||||
self.log_writer.add_scalar(tag=f"loss/val",
|
||||
scalar_value=self.val_loss_offset + mean_loss,
|
||||
global_step=global_step)
|
||||
self.loss_val_history.append(mean_loss)
|
||||
if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
|
||||
dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
|
||||
if np.average(dy) > 0:
|
||||
logging.warning(f"Validation loss shows diverging")
|
||||
# todo: signal stop?
|
||||
|
||||
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
|
||||
|
|
16
train.py
16
train.py
|
@ -780,8 +780,8 @@ def main(args):
|
|||
|
||||
# Pre-train validation to establish a starting point on the loss graph
|
||||
if validator:
|
||||
validator.do_validation_if_appropriate(epoch=0, global_step=0,
|
||||
get_model_prediction_and_target_callable=get_model_prediction_and_target)
|
||||
validator.do_validation(global_step=0,
|
||||
get_model_prediction_and_target_callable=get_model_prediction_and_target)
|
||||
|
||||
# the sample generator might be configured to generate samples before step 0
|
||||
if sample_generator.generate_pretrain_samples:
|
||||
|
@ -800,6 +800,12 @@ def main(args):
|
|||
steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
|
||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
|
||||
|
||||
validation_steps = (
|
||||
[] if validator is None
|
||||
else validator.get_validation_step_indices(len(train_dataloader))
|
||||
)
|
||||
print(f"validation on steps {validation_steps}")
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
step_start_time = time.time()
|
||||
|
||||
|
@ -860,6 +866,9 @@ def main(args):
|
|||
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if validator and step in validation_steps:
|
||||
validator.do_validation(global_step, get_model_prediction_and_target)
|
||||
|
||||
if (global_step + 1) % sample_generator.sample_steps == 0:
|
||||
generate_samples(global_step=global_step, batch=batch)
|
||||
|
||||
|
@ -895,9 +904,6 @@ def main(args):
|
|||
loss_local = sum(loss_epoch) / len(loss_epoch)
|
||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
|
||||
|
||||
if validator:
|
||||
validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target)
|
||||
|
||||
gc.collect()
|
||||
# end of epoch
|
||||
|
||||
|
|
Loading…
Reference in New Issue