Merge branch 'main' of https://github.com/victorchall/EveryDream2trainer

commit 6340b7fdfe
@@ -2,6 +2,7 @@ import json
 import logging
 import math
 import random
+from dataclasses import dataclass, field
 from typing import Callable, Any, Optional, Generator
 from argparse import Namespace
 
@@ -20,6 +21,8 @@ from data import aspects
 from data.image_train_item import ImageTrainItem
 from utils.isolate_rng import isolate_rng
 
+from colorama import Fore, Style
+
 
 def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
         -> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
@@ -35,6 +38,25 @@ def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]:
     for i in items:
         yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1)
 
 
+@dataclass
+class ValidationDataset:
+    name: str
+    dataloader: torch.utils.data.DataLoader
+    loss_history: list[float] = field(default_factory=list)
+    val_loss_window_size: Optional[int] = 5  # todo: arg for this?
+
+    def track_loss_trend(self, mean_loss: float):
+        if self.val_loss_window_size is None:
+            return
+        self.loss_history.append(mean_loss)
+
+        if len(self.loss_history) > ((self.val_loss_window_size * 2) + 1):
+            dy = np.diff(self.loss_history[-self.val_loss_window_size:])
+            if np.average(dy) > 0:
+                logging.warning(f"Validation loss for {self.name} shows diverging. Check your loss/{self.name} graph.")
+
+
 class EveryDreamValidator:
     def __init__(self,
                  val_config_path: Optional[str],
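For reference, the sliding-window divergence check that `ValidationDataset.track_loss_trend` performs can be read in isolation. This is a minimal sketch; the `is_diverging` helper and the sample numbers are illustrative, not code from this commit:

```python
import numpy as np

def is_diverging(loss_history: list[float], window_size: int = 5) -> bool:
    # Wait for more than (2 * window_size + 1) recorded losses before judging.
    if len(loss_history) <= (window_size * 2) + 1:
        return False
    # First differences over the most recent window; a positive mean slope
    # means validation loss has started trending upwards (likely overfitting).
    dy = np.diff(loss_history[-window_size:])
    return float(np.average(dy)) > 0

# A loss curve that bottoms out and then climbs triggers the warning:
losses = [0.9, 0.7, 0.5, 0.4, 0.35, 0.33, 0.34, 0.36, 0.39, 0.43, 0.48, 0.55]
print(is_diverging(losses))  # True
```

Passing `val_loss_window_size=None` (as the stabilize-train dataset does later in this diff) skips the check entirely.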
@@ -42,8 +64,7 @@ class EveryDreamValidator:
                  resolution: int,
                  log_writer: SummaryWriter,
                  ):
-        self.val_dataloader = None
-        self.train_overlapping_dataloader = None
+        self.validation_datasets = []
         self.resolution = resolution
         self.log_writer = log_writer
 
@@ -54,22 +75,38 @@ class EveryDreamValidator:
 
             'validate_training': True,
             'val_split_mode': 'automatic',
-            'val_split_proportion': 0.15,
+            'auto_split_proportion': 0.15,
 
             'stabilize_training_loss': False,
             'stabilize_split_proportion': 0.15,
 
             'use_relative_loss': False,
 
+            'extra_manual_datasets': {
+                # name: path pairs
+                # eg "santa suit": "/path/to/captioned_santa_suit_images", will be logged to tensorboard as "loss/santa suit"
+            }
         }
         if val_config_path is not None:
             with open(val_config_path, 'rt') as f:
                 self.config.update(json.load(f))
 
-        self.train_overlapping_dataloader_loss_offset = None
-        self.val_loss_offset = None
+        if 'val_data_root' in self.config:
+            logging.warning(f" * {Fore.YELLOW}using old name 'val_data_root' for 'manual_data_root' - please "
+                            f"update your validation config json{Style.RESET_ALL}")
+            self.config.update({'manual_data_root': self.config['val_data_root']})
+
+        if self.config.get('val_split_mode') == 'manual':
+            if 'manual_data_root' not in self.config:
+                raise ValueError("Error in validation config .json: 'manual' validation is missing 'manual_data_root'")
+            self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
+
+        if 'val_split_proportion' in self.config:
+            logging.warning(f" * {Fore.YELLOW}using old name 'val_split_proportion' for 'auto_split_proportion' - please "
+                            f"update your validation config json{Style.RESET_ALL}")
+            self.config.update({'auto_split_proportion': self.config['val_split_proportion']})
 
-        self.loss_val_history = []
-        self.val_loss_window_size = 5 # todo: arg for this?
-
     @property
     def batch_size(self):
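The two deprecation shims above just copy old config keys to their new names before anything else reads the config. A standalone sketch of that pattern; the `migrate_legacy_keys` helper and its rename table are illustrative, not code from this commit:

```python
import logging

# Old key -> new key renames, mirroring the shims above.
LEGACY_KEY_RENAMES = {
    'val_data_root': 'manual_data_root',
    'val_split_proportion': 'auto_split_proportion',
}

def migrate_legacy_keys(config: dict) -> dict:
    for old_key, new_key in LEGACY_KEY_RENAMES.items():
        if old_key in config:
            logging.warning(f"using old name '{old_key}' for '{new_key}' - "
                            f"please update your validation config json")
            config[new_key] = config[old_key]
    return config

config = migrate_legacy_keys({'val_split_proportion': 0.2})
print(config['auto_split_proportion'])  # 0.2
```

Note that the old keys are copied rather than deleted, so any code still reading the legacy names keeps working during the transition.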
@@ -96,46 +133,35 @@ class EveryDreamValidator:
         """
         with isolate_rng():
             random.seed(self.seed)
-            self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
+
+            auto_dataset, remaining_train_items = self._build_automatic_validation_dataset_if_required(train_items, tokenizer)
             # order is important - if we're removing images from train, this needs to happen before making
             # the overlapping dataloader
-            self.train_overlapping_dataloader = self._build_train_stabilizer_dataloader_if_required(
+            train_overlapping_dataset = self._build_train_stabilizer_dataloader_if_required(
                 remaining_train_items, tokenizer)
 
+            if auto_dataset is not None:
+                self.validation_datasets.append(auto_dataset)
+            if train_overlapping_dataset is not None:
+                self.validation_datasets.append(train_overlapping_dataset)
+            manual_splits = self._build_manual_validation_datasets(tokenizer)
+            self.validation_datasets.extend(manual_splits)
+
         return remaining_train_items
 
     def do_validation_if_appropriate(self, epoch: int, global_step: int,
                                      get_model_prediction_and_target_callable: Callable[
                                          [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
         if (epoch % self.every_n_epochs) == 0:
-            if self.train_overlapping_dataloader is not None:
-                mean_loss = self._calculate_validation_loss('stabilize-train',
-                                                            self.train_overlapping_dataloader,
+            for i, dataset in enumerate(self.validation_datasets):
+                mean_loss = self._calculate_validation_loss(dataset.name,
+                                                            dataset.dataloader,
                                                             get_model_prediction_and_target_callable)
-                if self.train_overlapping_dataloader_loss_offset is None:
-                    self.train_overlapping_dataloader_loss_offset = -mean_loss
-                self.log_writer.add_scalar(tag=f"loss/stabilize-train",
-                                           scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss,
+                self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
+                                           scalar_value=mean_loss,
                                            global_step=global_step)
-            if self.val_dataloader is not None:
-                mean_loss = self._calculate_validation_loss('val',
-                                                            self.val_dataloader,
-                                                            get_model_prediction_and_target_callable)
-                if self.val_loss_offset is None:
-                    self.val_loss_offset = -mean_loss
-                self.log_writer.add_scalar(tag=f"loss/val",
-                                           scalar_value=mean_loss if not self.use_relative_loss else self.val_loss_offset + mean_loss,
-                                           global_step=global_step)
+                dataset.track_loss_trend(mean_loss)
 
-                self.track_loss_trend(mean_loss)
-
-    def track_loss_trend(self, mean_loss):
-        self.loss_val_history.append(mean_loss)
-
-        if len(self.loss_val_history) > ((self.val_loss_window_size * 2) + 1):
-            dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
-            if np.average(dy) > 0:
-                logging.warning(f"Validation loss shows diverging. Check your val/loss graph.")
-
     def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
             [Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
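`_calculate_validation_loss` itself is unchanged by this commit; only its signature appears above as context. The core of such a routine is a no-grad pass over the dataloader that averages per-batch loss. A minimal sketch, assuming batch keys (`'image'`, `'tokens'`) and an MSE objective, neither of which is confirmed by this diff:

```python
import torch
import torch.nn.functional as F

def mean_loss_over_dataloader(dataloader, get_model_prediction_and_target) -> float:
    # Average per-batch loss over the whole validation dataloader,
    # without tracking gradients.
    batch_losses = []
    with torch.no_grad():
        for batch in dataloader:
            model_pred, target = get_model_prediction_and_target(batch['image'], batch['tokens'])
            loss = F.mse_loss(model_pred.float(), target.float(), reduction='mean')
            batch_losses.append(loss.item())
    return sum(batch_losses) / len(batch_losses)
```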
@@ -166,31 +192,35 @@ class EveryDreamValidator:
             return loss_validation_local
 
-    def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
-            -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
+    def _build_automatic_validation_dataset_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
+            -> tuple[Optional[ValidationDataset], list[ImageTrainItem]]:
         val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
-        val_split_proportion = self.config['val_split_proportion']
         remaining_train_items = image_train_items
-        if val_split_mode is None or val_split_mode == 'none':
+        if val_split_mode is None or val_split_mode == 'none' or val_split_mode == 'manual':
+            # manual is handled by _build_manual_validation_datasets
             return None, image_train_items
         elif val_split_mode == 'automatic':
-            val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
+            auto_split_proportion = self.config['auto_split_proportion']
+            val_items, remaining_train_items = get_random_split(image_train_items, auto_split_proportion, batch_size=self.batch_size)
             val_items = list(disable_multiplier_and_flip(val_items))
             logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
-        elif val_split_mode == 'manual':
-            val_data_root = self.config.get('val_data_root', None)
-            if val_data_root is None:
-                raise ValueError("Manual validation split requested but `val_data_root` is not defined in validation config")
-            val_items = self._load_manual_val_split(val_data_root)
-            logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
+            val_ed_batch = self._build_ed_batch(val_items, tokenizer=tokenizer, name='val')
+            val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
+            return ValidationDataset(name='val', dataloader=val_dataloader), remaining_train_items
         else:
             raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
-        val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
-        val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
-        return val_dataloader, remaining_train_items
+
+    def _build_manual_validation_datasets(self, tokenizer) -> list[ValidationDataset]:
+        datasets = []
+        for name, root in self.config.get('extra_manual_datasets', {}).items():
+            items = self._load_manual_val_split(root)
+            logging.info(f" * Loaded {len(items)} validation images for validation set '{name}' from {root}")
+            ed_batch = self._build_ed_batch(items, tokenizer=tokenizer, name=name)
+            dataloader = build_torch_dataloader(ed_batch, batch_size=self.batch_size)
+            datasets.append(ValidationDataset(name=name, dataloader=dataloader))
+        return datasets
 
     def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
-            -> Optional[torch.utils.data.DataLoader]:
+            -> Optional[ValidationDataset]:
         stabilize_training_loss = self.config['stabilize_training_loss']
         if not stabilize_training_loss:
             return None
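`get_random_split` is also untouched by this commit; its signature appears near the top of the diff. For intuition, a plausible batch-size-aware split looks like the sketch below. Rounding the validation count down to a multiple of `batch_size` is an assumption about why the function takes that parameter, not confirmed by the diff:

```python
import random
from typing import TypeVar

T = TypeVar('T')

def get_random_split_sketch(items: list[T], split_proportion: float,
                            batch_size: int) -> tuple[list[T], list[T]]:
    # Round the split size down to a full multiple of batch_size so the
    # validation dataloader only ever sees complete batches (assumed intent).
    split_count = int(len(items) * split_proportion)
    split_count -= split_count % batch_size
    shuffled = random.sample(items, len(items))  # shuffled copy; input untouched
    return shuffled[:split_count], shuffled[split_count:]

val_items, train_items = get_random_split_sketch(list(range(100)), 0.15, batch_size=4)
print(len(val_items), len(train_items))  # 12 88
```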
@@ -198,10 +228,9 @@ class EveryDreamValidator:
         stabilize_split_proportion = self.config['stabilize_split_proportion']
         stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
         stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
-        stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
-                                                  name='stabilize-train')
+        stabilize_ed_batch = self._build_ed_batch(stabilize_items, tokenizer=tokenizer, name='stabilize-train')
         stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
-        return stabilize_dataloader
+        return ValidationDataset(name='stabilize-train', dataloader=stabilize_dataloader, val_loss_window_size=None)
 
     def _load_manual_val_split(self, val_data_root: str):
         args = Namespace(
@@ -214,7 +243,7 @@ class EveryDreamValidator:
         random.shuffle(val_items)
         return val_items
 
-    def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
+    def _build_ed_batch(self, items: list[ImageTrainItem], tokenizer, name='val'):
+        batch_size = self.batch_size
         seed = self.seed
         data_loader = DataLoaderMultiAspect(
@@ -91,10 +91,11 @@ The config file has the following options:
 
 #### Validation settings
 * `validate_training`: If `true`, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.
 * `val_split_mode`: Either `automatic` or `manual`, ignored if validate_training is false.
-  * `automatic` val_split_mode picks a random subset of the training set (the number of items is controlled by `val_split_proportion`) and removes them from training to use as a validation set.
-  * `manual` val_split_mode lets you provide your own folder of validation items (images and captions), specified using `val_data_root`.
-* `val_split_proportion`: For `automatic` val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.
-* `val_data_root`: For `manual` val_split_mode, the path to a folder containing validation items.
+  * `automatic` val_split_mode picks a random subset of the training set (the number of items is controlled by `auto_split_proportion`) and removes them from training to use as a validation set.
+  * `manual` val_split_mode lets you provide your own folder of validation items (images and captions), specified using `manual_data_root`.
+* `auto_split_proportion`: For `automatic` val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.
+* `manual_data_root`: For `manual` val_split_mode, the path to a folder containing validation items.
+* `extra_manual_datasets`: Dictionary specifying additional folders containing validation datasets - see "Extra manual datasets" below.
 
 #### Train loss graph stabilization settings
@@ -105,3 +106,29 @@ The config file has the following options:
 
 * `every_n_epochs`: How often to run validation (1=every epoch).
 * `seed`: The seed to use when running validation passes, and also for picking subsets of the data to use with `automatic` val_split_mode and/or `stabilize_training_loss`.
+
+#### Extra manual datasets
+
+If you're building a model with multiple training subjects, you may want to specify additional validation datasets so you can check the progress of each part of your model separately. You can do this using the `extra_manual_datasets` property of the validation config .json file.
+
+For example, suppose you're training a model for different dog breeds, and you're especially interested in how well it's training huskies and puggles. To do this, take some of your husky and puggle training data and put it into two separate folders, outside of the data root. For example, suppose you have 100 husky images and 100 puggle images, like this:
+```commandline
+/workspace/dogs-model-training/data_root/husky   <- contains 100 images for training
+/workspace/dogs-model-training/data_root/puggle  <- contains 100 images for training
+```
+Take about 15 images from each of the `husky` and `puggle` folders and put them in their own `validation` folder, outside of the `data_root`:
+```commandline
+/workspace/dogs-model-training/validation/husky  <- contains 15 images for validation
+/workspace/dogs-model-training/validation/puggle <- contains 15 images for validation
+/workspace/dogs-model-training/data_root/husky   <- contains the remaining 85 images for training
+/workspace/dogs-model-training/data_root/puggle  <- contains the remaining 85 images for training
+```
+Then update your `validation_config.json` file by adding entries to `extra_manual_datasets` to point to these folders:
+```commandline
+"extra_manual_datasets": {
+    "husky": "/workspace/dogs-model-training/validation/husky",
+    "puggle": "/workspace/dogs-model-training/validation/puggle"
+}
+```
+When you run training, you'll now get two additional graphs, `loss/husky` and `loss/puggle`, that show the progress for your `husky` and `puggle` training data.
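Putting the documented options together, a complete validation config for the husky/puggle walkthrough above might look like the following sketch. It mirrors the defaults in the shipped config file whose diff follows, with `extra_manual_datasets` filled in using the illustrative paths from the walkthrough:

```json
{
    "validate_training": true,
    "val_split_mode": "automatic",
    "auto_split_proportion": 0.15,
    "manual_data_root": null,
    "extra_manual_datasets": {
        "husky": "/workspace/dogs-model-training/validation/husky",
        "puggle": "/workspace/dogs-model-training/validation/puggle"
    },
    "stabilize_training_loss": false,
    "stabilize_split_proportion": 0.15,
    "every_n_epochs": 1
}
```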
@@ -1,9 +1,10 @@
 {
     "documentation": {
         "validate_training": "If true, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.",
-        "val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
-        "val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
-        "val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
+        "val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by auto_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'manual_data_root'.",
+        "auto_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
+        "manual_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
+        "extra_manual_datasets": "Dictionary of 'name':'path' pairs defining additional validation datasets to load and log. eg { 'santa_suit': '/path/to/captioned_santa_suit_images', 'flamingo_suit': '/path/to/flamingo_suit_images' }",
         "stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
         "stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
         "every_n_epochs": "How often to run validation (1=every epoch).",
@@ -12,8 +13,9 @@
     },
     "validate_training": true,
     "val_split_mode": "automatic",
-    "val_data_root": null,
-    "val_split_proportion": 0.15,
+    "auto_split_proportion": 0.15,
+    "manual_data_root": null,
+    "extra_manual_datasets" : {},
     "stabilize_training_loss": false,
     "stabilize_split_proportion": 0.15,
     "every_n_epochs": 1,