From 87ffb5413e1962ad8b465be2fb3f0199180df629 Mon Sep 17 00:00:00 2001
From: Victor Hall
Date: Sat, 4 Mar 2023 15:08:47 -0500
Subject: [PATCH] log warning when val/loss is showing upward trend

---
 data/every_dream_validation.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py
index 95f5afd..49105d2 100644
--- a/data/every_dream_validation.py
+++ b/data/every_dream_validation.py
@@ -7,6 +7,7 @@ from typing import Callable, Any, Optional, Generator
 from argparse import Namespace
 
 import torch
+import numpy as np
 from colorama import Fore, Style
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
@@ -63,6 +64,9 @@ class EveryDreamValidator:
             with open(val_config_path, 'rt') as f:
                 self.config.update(json.load(f))
 
+        self.loss_val_history = []
+        self.val_loss_window_size = 4  # todo: arg for this?
+
     @property
     def batch_size(self):
         return self.config['batch_size']
@@ -99,7 +103,14 @@
             self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
                                 get_model_prediction_and_target_callable)
         if self.val_dataloader is not None:
-            self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
+            val_loss = self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
+
+            self.loss_val_history.append(val_loss)
+            if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
+                dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
+                if np.average(dy) > 0:
+                    logging.warning("Validation loss is showing an upward trend")
+                    # todo: signal stop?
 
     def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
         [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
@@ -127,6 +138,9 @@
         loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
         self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
+        return loss_validation_local
+
+
 
     def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
             -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
         val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
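
---

Editor's note: a standalone sketch of the trend check above. The patch flags divergence by taking first differences (np.diff) over the last few val/loss samples and warning when their mean is positive, i.e. the loss rose on average across the window. The class below is an illustrative invention for this note (the name ValLossTrendWatcher and its update() interface are not part of the repository); it restates the patch's inlined logic with the same window constant, assuming Python 3.9+ and numpy.

    import logging

    import numpy as np


    class ValLossTrendWatcher:
        """Tracks val/loss samples and warns when the recent trend is upward.

        Hypothetical helper mirroring the guard and window logic in the patch.
        """

        def __init__(self, window_size: int = 4):
            self.window_size = window_size
            self.history: list[float] = []

        def update(self, val_loss: float) -> bool:
            """Record one sample; return True if an upward trend was detected."""
            self.history.append(val_loss)
            # Same guard as the patch: collect 2 * window + 1 samples before judging.
            if len(self.history) <= self.window_size * 2 + 1:
                return False
            # First differences over the last `window_size` samples; a positive
            # mean difference means the loss rose on average across the window.
            dy = np.diff(self.history[-self.window_size:])
            if np.average(dy) > 0:
                logging.warning("Validation loss is showing an upward trend")
                return True
            return False


    # Example: steady improvement followed by a rising tail trips the check
    # once enough history has accumulated (here, at the 10th sample).
    watcher = ValLossTrendWatcher(window_size=4)
    for step, loss in enumerate([0.30, 0.28, 0.27, 0.26, 0.25, 0.25, 0.26, 0.27, 0.28, 0.30]):
        if watcher.update(loss):
            print(f"upward trend detected at sample {step}")

One observation on the design: the mean of consecutive differences telescopes to (last - first) / (window_size - 1), so the check effectively compares the first and last samples of the window. Intermediate noise cancels out, but a single bad final sample can still trip the warning, which is one reason the todo'd window-size argument (or smoothing the history first) might be worth adding.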