implement extra manual validation splits

This commit is contained in:
damian 2023-04-29 17:31:08 +02:00 committed by Victor Hall
parent aad00eab2e
commit 413f981512
3 changed files with 117 additions and 64 deletions

View File

@ -2,6 +2,7 @@ import json
import logging
import math
import random
from dataclasses import dataclass, field
from typing import Callable, Any, Optional, Generator
from argparse import Namespace
@ -20,6 +21,8 @@ from data import aspects
from data.image_train_item import ImageTrainItem
from utils.isolate_rng import isolate_rng
from colorama import Fore, Style
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
-> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
@ -35,6 +38,25 @@ def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageT
for i in items:
yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1)
@dataclass
class ValidationDataset:
name: str
dataloader: torch.utils.data.DataLoader
loss_history: list[float] = field(default_factory=list)
val_loss_window_size: Optional[int] = 5 # todo: arg for this?
def track_loss_trend(self, mean_loss: float):
if self.val_loss_window_size is None:
return
self.loss_history.append(mean_loss)
if len(self.loss_history) > ((self.val_loss_window_size * 2) + 1):
dy = np.diff(self.loss_history[-self.val_loss_window_size:])
if np.average(dy) > 0:
logging.warning(f"Validation loss for {self.name} shows diverging. Check your loss/{self.name} graph.")
class EveryDreamValidator:
def __init__(self,
val_config_path: Optional[str],
@ -42,8 +64,7 @@ class EveryDreamValidator:
resolution: int,
log_writer: SummaryWriter,
):
self.val_dataloader = None
self.train_overlapping_dataloader = None
self.validation_datasets = []
self.resolution = resolution
self.log_writer = log_writer
@ -54,22 +75,33 @@ class EveryDreamValidator:
'validate_training': True,
'val_split_mode': 'automatic',
'val_split_proportion': 0.15,
'auto_split_proportion': 0.15,
'stabilize_training_loss': False,
'stabilize_split_proportion': 0.15,
'use_relative_loss': False,
'extra_manual_datasets': {
# name: path pairs
# eg "santa suit": "/path/to/captioned_santa_suit_images", will be logged to tensorboard as "loss/santa suit"
}
}
if val_config_path is not None:
with open(val_config_path, 'rt') as f:
self.config.update(json.load(f))
self.train_overlapping_dataloader_loss_offset = None
self.val_loss_offset = None
if self.config.get('val_split_mode') == 'manual':
if 'manual_data_root' not in self.config:
raise ValueError("Error in validation config .json: 'manual' validation is missing 'manual_data_root'")
self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
if 'val_split_proportion' in self.config:
print(f" * {Fore.YELLOW}using old name 'val_split_proportion' for 'auto_split_proportion' - please "
f"update your validation config json{Style.RESET_ALL}")
self.config.update({'auto_split_proportion': self.config['val_split_proportion']})
self.loss_val_history = []
self.val_loss_window_size = 5 # todo: arg for this?
@property
def batch_size(self):
@ -96,46 +128,35 @@ class EveryDreamValidator:
"""
with isolate_rng():
random.seed(self.seed)
self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
auto_dataset, remaining_train_items = self._build_automatic_validation_dataset_if_required(train_items, tokenizer)
# order is important - if we're removing images from train, this needs to happen before making
# the overlapping dataloader
self.train_overlapping_dataloader = self._build_train_stabilizer_dataloader_if_required(
train_overlapping_dataset = self._build_train_stabilizer_dataloader_if_required(
remaining_train_items, tokenizer)
if auto_dataset is not None:
self.validation_datasets.append(auto_dataset)
if train_overlapping_dataset is not None:
self.validation_datasets.append(train_overlapping_dataset)
manual_splits = self._build_manual_validation_datasets(tokenizer)
self.validation_datasets.extend(manual_splits)
return remaining_train_items
def do_validation_if_appropriate(self, epoch: int, global_step: int,
get_model_prediction_and_target_callable: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
if (epoch % self.every_n_epochs) == 0:
if self.train_overlapping_dataloader is not None:
mean_loss = self._calculate_validation_loss('stabilize-train',
self.train_overlapping_dataloader,
for i, dataset in enumerate(self.validation_datasets):
mean_loss = self._calculate_validation_loss(dataset.name,
dataset.dataloader,
get_model_prediction_and_target_callable)
if self.train_overlapping_dataloader_loss_offset is None:
self.train_overlapping_dataloader_loss_offset = -mean_loss
self.log_writer.add_scalar(tag=f"loss/stabilize-train",
scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss,
self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
scalar_value=mean_loss,
global_step=global_step)
if self.val_dataloader is not None:
mean_loss = self._calculate_validation_loss('val',
self.val_dataloader,
get_model_prediction_and_target_callable)
if self.val_loss_offset is None:
self.val_loss_offset = -mean_loss
self.log_writer.add_scalar(tag=f"loss/val",
scalar_value=mean_loss if not self.use_relative_loss else self.val_loss_offset + mean_loss,
global_step=global_step)
dataset.track_loss_trend(mean_loss)
self.track_loss_trend(mean_loss)
def track_loss_trend(self, mean_loss):
self.loss_val_history.append(mean_loss)
if len(self.loss_val_history) > ((self.val_loss_window_size * 2) + 1):
dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
if np.average(dy) > 0:
logging.warning(f"Validation loss shows diverging. Check your val/loss graph.")
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
@ -166,31 +187,35 @@ class EveryDreamValidator:
return loss_validation_local
def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
-> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
def _build_automatic_validation_dataset_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
-> tuple[Optional[ValidationDataset], list[ImageTrainItem]]:
val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
val_split_proportion = self.config['val_split_proportion']
remaining_train_items = image_train_items
if val_split_mode is None or val_split_mode == 'none':
if val_split_mode is None or val_split_mode == 'none' or val_split_mode == 'manual':
# manual is handled by _build_manual_validation_datasets
return None, image_train_items
elif val_split_mode == 'automatic':
val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
auto_split_proportion = self.config['auto_split_proportion']
val_items, remaining_train_items = get_random_split(image_train_items, auto_split_proportion, batch_size=self.batch_size)
val_items = list(disable_multiplier_and_flip(val_items))
logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
elif val_split_mode == 'manual':
val_data_root = self.config.get('val_data_root', None)
if val_data_root is None:
raise ValueError("Manual validation split requested but `val_data_root` is not defined in validation config")
val_items = self._load_manual_val_split(val_data_root)
logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
val_ed_batch = self._build_ed_batch(val_items, tokenizer=tokenizer, name='val')
val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
return ValidationDataset(name='val', dataloader=val_dataloader), remaining_train_items
else:
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
return val_dataloader, remaining_train_items
def _build_manual_validation_datasets(self, tokenizer) -> list[ValidationDataset]:
datasets = []
for name, root in self.config.get('extra_manual_datasets', {}).items():
items = self._load_manual_val_split(root)
logging.info(f" * Loaded {len(items)} validation images for validation set '{name}' from {root}")
ed_batch = self._build_ed_batch(items, tokenizer=tokenizer, name=name)
dataloader = build_torch_dataloader(ed_batch, batch_size=self.batch_size)
datasets.append(ValidationDataset(name=name, dataloader=dataloader))
return datasets
def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
-> Optional[torch.utils.data.DataLoader]:
-> Optional[ValidationDataset]:
stabilize_training_loss = self.config['stabilize_training_loss']
if not stabilize_training_loss:
return None
@ -198,10 +223,9 @@ class EveryDreamValidator:
stabilize_split_proportion = self.config['stabilize_split_proportion']
stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
name='stabilize-train')
stabilize_ed_batch = self._build_ed_batch(stabilize_items, tokenizer=tokenizer, name='stabilize-train')
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
return stabilize_dataloader
return ValidationDataset(name='stabilize-train', dataloader=stabilize_dataloader, val_loss_window_size=None)
def _load_manual_val_split(self, val_data_root: str):
args = Namespace(
@ -214,7 +238,7 @@ class EveryDreamValidator:
random.shuffle(val_items)
return val_items
def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
def _build_ed_batch(self, items: list[ImageTrainItem], tokenizer, name='val'):
batch_size = self.batch_size
seed = self.seed
data_loader = DataLoaderMultiAspect(

View File

@ -91,10 +91,11 @@ The config file has the following options:
#### Validation settings
* `validate_training`: If `true`, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.
* `val_split_mode`: Either `automatic` or `manual`, ignored if validate_training is false.
* `automatic` val_split_mode picks a random subset of the training set (the number of items is controlled by `val_split_proportion`) and removes them from training to use as a validation set.
* `manual` val_split_mode lets you provide your own folder of validation items (images and captions), specified using `val_data_root`.
* `val_split_proportion`: For `automatic` val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.
* `val_data_root`: For `manual` val_split_mode, the path to a folder containing validation items.
* `automatic` val_split_mode picks a random subset of the training set (the number of items is controlled by `auto_split_proportion`) and removes them from training to use as a validation set.
* `manual` val_split_mode lets you provide your own folder of validation items (images and captions), specified using `manual_data_root`.
* `auto_split_proportion`: For `automatic` val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.
* `manual_data_root`: For `manual` val_split_mode, the path to a folder containing validation items.
* `extra_manual_datasets`: Dictionary specifying additional folders containing validation datasets - see "Extra manual datasets" below.
#### Train loss graph stabilization settings
@ -105,3 +106,29 @@ The config file has the following options:
* `every_n_epochs`: How often to run validation (1=every epoch).
* `seed`: The seed to use when running validation passes, and also for picking subsets of the data to use with `automatic` val_split_mode and/or `stabilize_training_loss`.
#### Extra manual datasets
If you're building a model with multiple training subjects, you may want to specify additional validation datasets so you can check the progress of each part of your model separately. You can do this using the `extra_manual_datasets` property of the validation config .json file.
For example, suppose you're training a model for different dog breeds, and you're especially interested in how well it's training huskies and puggles. To do this, take some of your husky and puggle training data and put it into two separate folders, outside of the data root. For example, suppose you have 100 husky images and 100 puggle images, like this:
```commandline
/workspace/dogs-model-training/data_root/husky <- contains 100 images for training
/workspace/dogs-model-training/data_root/puggle <- contains 100 images for training
```
Take about 15 images from each of the `husky` and `puggle` folders and put them in their own `validation` folder, outside of the `data_root`:
```commandline
/workspace/dogs-model-training/validation/husky <- contains 15 images for validation
/workspace/dogs-model-training/validation/puggle <- contains 15 images for validation
/workspace/dogs-model-training/data_root/husky <- contains the remaining 85 images for training
/workspace/dogs-model-training/data_root/puggle <- contains the remaining 85 images for training
```
Then update your `validation_config.json` file by adding entries to `extra_manual_datasets` to point to these folders:
```commandline
"extra_manual_datasets": {
"husky": "/workspace/dogs-model-training/validation/husky",
"puggle": "/workspace/dogs-model-training/validation/puggle"
}
```
When you run training, you'll now get two additional graphs, `loss/husky` and `loss/puggle` that show the progress for your `husky` and `puggle` training data.
When you run training, you'll now get two additional graphs, `loss/husky` and `loss/puggle` that show the progress for your `husky` and `puggle` training data.

View File

@ -1,9 +1,10 @@
{
"documentation": {
"validate_training": "If true, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.",
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
"val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by auto_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'manual_data_root'.",
"auto_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"manual_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
"extra_manual_datasets": "Dictionary of 'name':'path' pairs defining additional validation datasets to load and log. eg { 'santa_suit': '/path/to/captioned_santa_suit_images', 'flamingo_suit': '/path/to/flamingo_suit_images' }",
"stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"every_n_epochs": "How often to run validation (1=every epoch).",
@ -12,8 +13,9 @@
},
"validate_training": true,
"val_split_mode": "automatic",
"val_data_root": null,
"val_split_proportion": 0.15,
"auto_split_proportion": 0.15,
"manual_data_root": null,
"extra_manual_datasets" : {},
"stabilize_training_loss": false,
"stabilize_split_proportion": 0.15,
"every_n_epochs": 1,