log warning when val/loss is showing upward trend

This commit is contained in:
Victor Hall 2023-03-04 15:08:47 -05:00
parent ba87b0cae1
commit 87ffb5413e
1 changed file with 15 additions and 1 deletion

View File

@ -7,6 +7,7 @@ from typing import Callable, Any, Optional, Generator
from argparse import Namespace
import torch
import numpy as np
from colorama import Fore, Style
import torch.nn.functional as F
from torch.utils.data import DataLoader
@ -63,6 +64,9 @@ class EveryDreamValidator:
with open(val_config_path, 'rt') as f:
self.config.update(json.load(f))
self.loss_val_history = []
self.val_loss_window_size = 4 # todo: arg for this?
@property
def batch_size(self):
    """Return the ``'batch_size'`` entry of this validator's config."""
    settings = self.config
    return settings['batch_size']
@ -99,7 +103,14 @@ class EveryDreamValidator:
self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
get_model_prediction_and_target_callable)
if self.val_dataloader is not None:
self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
val_loss = self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
self.loss_val_history.append(val_loss)
if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
if np.average(dy) > 0:
logging.warning(f"Validation loss shows diverging")
# todo: signal stop?
def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
@ -127,6 +138,9 @@ class EveryDreamValidator:
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
return loss_validation_local
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None