log warning when val/loss is showing upward trend
This commit is contained in:
parent
ba87b0cae1
commit
87ffb5413e
|
@ -7,6 +7,7 @@ from typing import Callable, Any, Optional, Generator
|
|||
from argparse import Namespace
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from colorama import Fore, Style
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -63,6 +64,9 @@ class EveryDreamValidator:
|
|||
with open(val_config_path, 'rt') as f:
|
||||
self.config.update(json.load(f))
|
||||
|
||||
self.loss_val_history = []
|
||||
self.val_loss_window_size = 4 # todo: arg for this?
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self.config['batch_size']
|
||||
|
@ -99,7 +103,14 @@ class EveryDreamValidator:
|
|||
self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
|
||||
get_model_prediction_and_target_callable)
|
||||
if self.val_dataloader is not None:
|
||||
self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
|
||||
val_loss = self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
|
||||
|
||||
self.loss_val_history.append(val_loss)
|
||||
if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
|
||||
dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
|
||||
if np.average(dy) > 0:
|
||||
logging.warning(f"Validation loss shows diverging")
|
||||
# todo: signal stop?
|
||||
|
||||
def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||
|
@ -127,6 +138,9 @@ class EveryDreamValidator:
|
|||
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
|
||||
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
|
||||
|
||||
return loss_validation_local
|
||||
|
||||
|
||||
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
|
||||
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
|
||||
val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
|
||||
|
|
Loading…
Reference in New Issue