GH-36: Add support for validation split (WIP)

Co-authored-by: Damian Stewart <office@damianstewart.com>
Joel Holdbrooks 2023-02-06 22:10:34 -08:00
parent 85f19b9a2f
commit 41c9f36ed7
8 changed files with 417 additions and 92 deletions

data/data_loader.py

@@ -14,11 +14,12 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import bisect
from functools import reduce
import math
import copy
import random
-from data.image_train_item import ImageTrainItem
+from data.image_train_item import ImageTrainItem, ImageCaption
import PIL
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@@ -27,23 +28,19 @@ class DataLoaderMultiAspect():
"""
Data loader for multi-aspect-ratio training and bucketing
-image_train_items: list of `lImageTrainItem` objects
+image_train_items: list of `ImageTrainItem` objects
seed: random seed
batch_size: number of images per batch
"""
def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
self.seed = seed
self.batch_size = batch_size
# Prepare data
self.prepared_train_data = image_train_items
random.Random(self.seed).shuffle(self.prepared_train_data)
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
# Initialize ratings
self.rating_overall_sum: float = 0.0
self.ratings_summed: list[float] = []
-for image in self.prepared_train_data:
-self.rating_overall_sum += image.caption.rating()
-self.ratings_summed.append(self.rating_overall_sum)
+self.__update_rating_sums()
def __pick_multiplied_set(self, randomizer):
"""
@@ -80,14 +77,17 @@ class DataLoaderMultiAspect():
del data_copy
return picked_images
-def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
+def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
"""
-returns the current list of images including their captions in a randomized order,
-sorted into buckets with same sized images
-if dropout_fraction < 1.0, only a subset of the images will be returned
-if dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan
+Returns the current list of `ImageTrainItem` in randomized order,
+sorted into buckets with same sized images.
+If dropout_fraction < 1.0, only a subset of the images will be returned.
+If dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan.
:param dropout_fraction: must be between 0.0 and 1.0.
-:return: randomized list of (image, caption) pairs, sorted into same sized buckets
+:return: Randomized list of `ImageTrainItem` objects
"""
self.seed += 1
@@ -126,11 +126,11 @@ class DataLoaderMultiAspect():
buckets[bucket].extend(runt_bucket)
# flatten the buckets
-image_caption_pairs = []
+items: list[ImageTrainItem] = []
for bucket in buckets:
-image_caption_pairs.extend(buckets[bucket])
+items.extend(buckets[bucket])
-return image_caption_pairs
+return items
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
"""
@@ -168,3 +168,28 @@ class DataLoaderMultiAspect():
prepared_train_data.pop(pos)
return picked_images
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
item_count = math.ceil(split_proportion * len(self.prepared_train_data) // self.batch_size) * self.batch_size
# sort first, then shuffle, to ensure a deterministic outcome for the current random state
items_copy = list(sorted(self.prepared_train_data, key=lambda i: i.pathname))
random.shuffle(items_copy)
split_items = items_copy[:item_count]
if remove_from_dataset:
self.delete_items(split_items)
return split_items
def delete_items(self, items: list[ImageTrainItem]):
for item in items:
for i, other_item in enumerate(self.prepared_train_data):
if other_item.pathname == item.pathname:
self.prepared_train_data.pop(i)
break
self.__update_rating_sums()
def __update_rating_sums(self):
self.rating_overall_sum: float = 0.0
self.ratings_summed: list[float] = []
for item in self.prepared_train_data:
self.rating_overall_sum += item.caption.rating()
self.ratings_summed.append(self.rating_overall_sum)
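Taken together, the DataLoaderMultiAspect changes let a caller carve a validation split out of the prepared training data while keeping the rating sums consistent. A minimal sketch of how the pieces compose, assuming `train_items` is an already-resolved list of `ImageTrainItem` objects (the name is illustrative):

from data.data_loader import DataLoaderMultiAspect

dlma = DataLoaderMultiAspect(train_items, seed=555, batch_size=4)
# hold out ~15% of items, rounded to a whole number of batches, and drop them from training
val_items = dlma.get_random_split(0.15, remove_from_dataset=True)
# the remaining items are still available for bucketed, shuffled training
train_buckets = dlma.get_shuffled_image_buckets(dropout_fraction=1.0)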

ldm/data/dl_singleton.py (deleted)

@@ -1,2 +0,0 @@
# stop lightning's repeated instantiation of batch train/val/test classes causing multiple sweeps of the same data off disk
shared_dataloader = None

ldm/data/ed_validate.py (deleted)

@@ -1,70 +0,0 @@
"""
Copyright [2022] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from torch.utils.data import Dataset
from ldm.data.data_loader import DataLoaderMultiAspect as dlma
import math
import ldm.data.dl_singleton as dls
from ldm.data.image_train_item import ImageTrainItem
class EDValidateBatch(Dataset):
def __init__(self,
data_root,
flip_p=0.0,
repeats=1,
debug_level=0,
batch_size=1,
set='val',
):
self.data_root = data_root
self.batch_size = batch_size
if not dls.shared_dataloader:
print("Creating new dataloader singleton")
dls.shared_dataloader = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size, flip_p=flip_p)
self.image_train_items = dls.shared_dataloader.get_all_images()
self.num_images = len(self.image_train_items)
self._length = max(math.trunc(self.num_images * repeats), batch_size) - self.num_images % self.batch_size
print()
print(f" ** Validation Set: {set}, steps: {self._length / batch_size:.0f}, repeats: {repeats} ")
print()
def __len__(self):
return self._length
def __getitem__(self, i):
idx = i % self.num_images
image_train_item = self.image_train_items[idx]
example = self.__get_image_for_trainer(image_train_item)
return example
@staticmethod
def __get_image_for_trainer(image_train_item: ImageTrainItem):
example = {}
image_train_tmp = image_train_item.hydrate()
example["image"] = image_train_tmp.image
example["caption"] = image_train_tmp.caption
return example

data/every_dream.py

@@ -41,7 +41,8 @@ class EveryDreamBatch(Dataset):
retain_contrast=False,
shuffle_tags=False,
rated_dataset=False,
-rated_dataset_dropout_target=0.5
+rated_dataset_dropout_target=0.5,
+name='train'
):
self.data_loader = data_loader
self.batch_size = data_loader.batch_size
@@ -57,11 +58,19 @@ class EveryDreamBatch(Dataset):
self.rated_dataset = rated_dataset
self.rated_dataset_dropout_target = rated_dataset_dropout_target
# First epoch always trains on all images
-self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0)
+self.image_train_items = []
+self.__update_image_train_items(1.0)
+self.name = name
num_images = len(self.image_train_items)
logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
self.__update_image_train_items(1.0)
return items
def shuffle(self, epoch_n: int, max_epochs: int):
self.seed += 1
@@ -69,8 +78,8 @@ class EveryDreamBatch(Dataset):
dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
else:
dropout_fraction = 1.0
-self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
+self.__update_image_train_items(dropout_fraction)
def __len__(self):
return len(self.image_train_items)
@@ -130,3 +139,38 @@ class EveryDreamBatch(Dataset):
example["caption"] = image_train_tmp.caption
example["runt_size"] = image_train_tmp.runt_size
return example
def __update_image_train_items(self, dropout_fraction: float):
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
def build_torch_dataloader(items, batch_size) -> torch.utils.data.DataLoader:
dataloader = torch.utils.data.DataLoader(
items,
batch_size=batch_size,
shuffle=False,
num_workers=0,
collate_fn=collate_fn
)
return dataloader
def collate_fn(batch):
"""
Collates batches
"""
images = [example["image"] for example in batch]
captions = [example["caption"] for example in batch]
tokens = [example["tokens"] for example in batch]
runt_size = batch[0]["runt_size"]
images = torch.stack(images)
images = images.to(memory_format=torch.contiguous_format).float()
ret = {
"tokens": torch.stack(tuple(tokens)),
"image": images,
"captions": captions,
"runt_size": runt_size,
}
del batch
return ret
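For context, `build_torch_dataloader` wraps a dataset in a standard `torch.utils.data.DataLoader` with shuffling disabled (bucket order is already randomized upstream), and `collate_fn` stacks the per-example tensors while forwarding the first example's `runt_size`. A hedged usage sketch, assuming `val_batch` is an `EveryDreamBatch` (or any dataset yielding the same keys):

from data.every_dream import build_torch_dataloader

dataloader = build_torch_dataloader(val_batch, batch_size=4)
for batch in dataloader:
    images = batch["image"]         # stacked float tensor, contiguous memory format
    tokens = batch["tokens"]        # stacked token tensors
    captions = batch["captions"]    # list of raw caption strings
    runt_size = batch["runt_size"]  # padding count carried over from bucketing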

data/every_dream_validation.py (new file)

@@ -0,0 +1,170 @@
import json
import random
from typing import Callable, Any, Optional
from argparse import Namespace
import torch
from colorama import Fore, Style
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from data.every_dream import build_torch_dataloader, EveryDreamBatch
from data.data_loader import DataLoaderMultiAspect
from data import resolver
from data import aspects
from utils.isolate_rng import isolate_rng
class EveryDreamValidator:
def __init__(self,
val_config_path: Optional[str],
train_batch: EveryDreamBatch,
log_writer: SummaryWriter):
self.log_writer = log_writer
val_config = {}
if val_config_path is not None:
with open(val_config_path, 'rt') as f:
val_config = json.load(f)
do_validation = val_config.get('validate_training', False)
val_split_mode = val_config.get('val_split_mode', 'automatic') if do_validation else 'none'
self.val_data_root = val_config.get('val_data_root', None)
val_split_proportion = val_config.get('val_split_proportion', 0.15)
stabilize_training_loss = val_config.get('stabilize_training_loss', False)
stabilize_split_proportion = val_config.get('stabilize_split_proportion', 0.15)
self.every_n_epochs = val_config.get('every_n_epochs', 1)
self.seed = val_config.get('seed', 555)
with isolate_rng():
self.val_dataloader = self._build_validation_dataloader(val_split_mode,
split_proportion=val_split_proportion,
val_data_root=self.val_data_root,
train_batch=train_batch)
# order is important - if we're removing images from train, this needs to happen before making
# the overlapping dataloader
self.train_overlapping_dataloader = self._build_dataloader_from_automatic_split(train_batch,
split_proportion=stabilize_split_proportion,
name='train-stabilizer',
enforce_split=False) if stabilize_training_loss else None
def do_validation_if_appropriate(self, epoch: int, global_step: int,
get_model_prediction_and_target_callable: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
if (epoch % self.every_n_epochs) == 0:
if self.train_overlapping_dataloader is not None:
self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader, get_model_prediction_and_target_callable)
if self.val_dataloader is not None:
self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
with torch.no_grad(), isolate_rng():
loss_validation_epoch = []
steps_pbar = tqdm(range(len(dataloader)), position=1)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
for step, batch in enumerate(dataloader):
# ok to override seed here because we are in a `with isolate_rng():` block
torch.manual_seed(self.seed + step)
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
# del timesteps, encoder_hidden_states, noisy_latents
# with autocast(enabled=args.amp):
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
del target, model_pred
loss_step = loss.detach().item()
loss_validation_epoch.append(loss_step)
steps_pbar.update(1)
steps_pbar.close()
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
def _build_validation_dataloader(self,
validation_split_mode: str,
split_proportion: float,
val_data_root: Optional[str],
train_batch: EveryDreamBatch) -> Optional[DataLoader]:
if validation_split_mode == 'none':
return None
elif validation_split_mode == 'automatic':
return self._build_dataloader_from_automatic_split(train_batch, split_proportion, name='val', enforce_split=True)
elif validation_split_mode == 'manual':
if val_data_root is None:
raise ValueError("val_data_root is required for 'manual' validation split mode")
return self._build_dataloader_from_custom_split(self.val_data_root, reference_train_batch=train_batch)
else:
raise ValueError(f"unhandled validation split mode '{validation_split_mode}'")
def _build_dataloader_from_automatic_split(self,
train_batch: EveryDreamBatch,
split_proportion: float,
name: str,
enforce_split: bool=False) -> DataLoader:
"""
Build a validation dataloader by copying data from the given `train_batch`. If `enforce_split` is `True`, remove
the copied items from train_batch so that there is no overlap between `train_batch` and the new dataloader.
"""
with isolate_rng():
random.seed(self.seed)
val_items = train_batch.get_random_split(split_proportion, remove_from_dataset=enforce_split)
if enforce_split:
print(
f" * Removed {len(val_items)} items for validation split from '{train_batch.name}' - {round(len(train_batch)/train_batch.batch_size)} batches are left")
if len(train_batch) == 0:
raise ValueError(f"Validation split used up all of the training data. Try a lower split proportion than {split_proportion}")
val_batch = self._make_val_batch_with_train_batch_settings(
val_items,
train_batch,
name=name
)
return build_torch_dataloader(
items=val_batch,
batch_size=train_batch.batch_size,
)
def _build_dataloader_from_custom_split(self, data_root: str, reference_train_batch: EveryDreamBatch) -> DataLoader:
val_batch = self._make_val_batch_with_train_batch_settings(data_root, reference_train_batch)
return build_torch_dataloader(
items=val_batch,
batch_size=reference_train_batch.batch_size
)
def _make_val_batch_with_train_batch_settings(self, data_root, reference_train_batch, name='val'):
batch_size = reference_train_batch.batch_size
seed = reference_train_batch.seed
args = Namespace(
aspects=aspects.get_aspect_buckets(512),
flip_p=0.0,
seed=seed,
)
image_train_items = resolver.resolve(data_root, args)
data_loader = DataLoaderMultiAspect(
image_train_items,
batch_size=batch_size,
seed=seed,
)
val_batch = EveryDreamBatch(
data_loader=data_loader,
debug_level=1,
batch_size=batch_size,
conditional_dropout=0,
tokenizer=reference_train_batch.tokenizer,
seed=seed,
name=name,
)
return val_batch
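The intended integration, as far as this WIP commit shows, is to construct one `EveryDreamValidator` up front and call it once per epoch from the training loop. A sketch, where `train_batch`, `log_writer`, `max_epochs`, `global_step`, and `get_model_prediction_and_target` stand in for the trainer's own objects:

validator = EveryDreamValidator('validation_default.json',
                                train_batch=train_batch,
                                log_writer=log_writer)
for epoch in range(max_epochs):
    # ... training steps update global_step ...
    validator.do_validation_if_appropriate(epoch, global_step,
                                           get_model_prediction_and_target)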

utils/isolate_rng.py (new file, 73 lines)

@@ -0,0 +1,73 @@
# copy/pasted from pytorch lightning
# https://github.com/Lightning-AI/lightning/blob/0d52f4577310b5a1624bed4d23d49e37fb05af9e/src/lightning_fabric/utilities/seed.py
# and
# https://github.com/Lightning-AI/lightning/blob/98f7696d1681974d34fad59c03b4b58d9524ed13/src/pytorch_lightning/utilities/seed.py
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Generator, Dict, Any
import torch
import numpy as np
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
states = {
"torch": torch.get_rng_state(),
"numpy": np.random.get_state(),
"python": python_get_rng_state(),
}
if include_cuda:
states["torch.cuda"] = torch.cuda.get_rng_state_all()
return states
def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
"""Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current
process."""
torch.set_rng_state(rng_state_dict["torch"])
# torch.cuda rng_state is only included since v1.8.
if "torch.cuda" in rng_state_dict:
torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
np.random.set_state(rng_state_dict["numpy"])
version, state, gauss = rng_state_dict["python"]
python_set_rng_state((version, tuple(state), gauss))
@contextmanager
def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
"""A context manager that resets the global random state on exit to what it was before entering.
It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.
Args:
include_cuda: Whether to allow this function to also control the `torch.cuda` random number generator.
Set this to ``False`` when using the function in a forked process where CUDA re-initialization is
prohibited.
Example:
>>> import torch
>>> torch.manual_seed(1) # doctest: +ELLIPSIS
<torch._C.Generator object at ...>
>>> with isolate_rng():
... [torch.rand(1) for _ in range(3)]
[tensor([0.7576]), tensor([0.2793]), tensor([0.4031])]
>>> torch.rand(1)
tensor([0.7576])
"""
states = _collect_rng_states(include_cuda)
yield
_set_rng_states(states)
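A quick illustration of the guarantee: random numbers consumed inside the context do not advance the global state observed outside it, so the first draw after the block repeats the first draw inside it. A small self-contained check (module path taken from this commit's layout):

import random
from utils.isolate_rng import isolate_rng

random.seed(123)
with isolate_rng():
    inside = random.random()   # advances the global Python RNG...
outside = random.random()      # ...but the state was restored on exit
assert inside == outside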

utils/split_dataset.py (new file, 65 lines)

@@ -0,0 +1,65 @@
import argparse
import math
import os.path
import random
import shutil
from typing import Iterable, Optional
from tqdm.auto import tqdm
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']
def gather_captioned_images(root_dir: str) -> Iterable[tuple[str, Optional[str]]]:
for directory, _, filenames in os.walk(root_dir):
image_filenames = [f for f in filenames if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS]
for image_filename in image_filenames:
caption_filename = os.path.splitext(image_filename)[0] + '.txt'
image_path = os.path.join(directory, image_filename)
caption_path = os.path.join(directory, caption_filename)
yield (image_path, caption_path if os.path.exists(caption_path) else None)
def copy_captioned_image(image_caption_pair: tuple[str, Optional[str]], source_root: str, target_root: str):
image_path = image_caption_pair[0]
caption_path = image_caption_pair[1]
# make target folder if necessary
relative_folder = os.path.dirname(os.path.relpath(image_path, source_root))
target_folder = os.path.join(target_root, relative_folder)
os.makedirs(target_folder, exist_ok=True)
# copy files
shutil.copy2(image_path, os.path.join(target_folder, os.path.basename(image_path)))
if caption_path is not None:
shutil.copy2(caption_path, os.path.join(target_folder, os.path.basename(caption_path)))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('source_root', type=str, help='Source root folder containing images')
parser.add_argument('--train_output_folder', type=str, required=True, help="Output folder for the 'train' dataset")
parser.add_argument('--val_output_folder', type=str, required=True, help="Output folder for the 'val' dataset")
parser.add_argument('--split_proportion', type=float, required=True, help="Proportion of images to use for 'val' (a number between 0 and 1)")
parser.add_argument('--seed', type=int, required=False, default=555, help='Random seed for shuffling')
args = parser.parse_args()
images = list(gather_captioned_images(args.source_root))
print(f"Found {len(images)} captioned images in {args.source_root}")
val_split_count = math.ceil(len(images) * args.split_proportion)
if val_split_count == 0:
raise ValueError(f"No images in validation split with source count {len(images)} and split proportion {args.split_proportion}")
random.seed(args.seed)
random.shuffle(images)
val_split = images[0:val_split_count]
train_split = images[val_split_count:]
print(f"Split to 'train' set with {len(train_split)} images and 'val' set with {len(val_split)}")
print(f"Copying 'val' set to {args.val_output_folder}...")
for v in tqdm(val_split):
copy_captioned_image(v, args.source_root, args.val_output_folder)
print(f"Copying 'train' set to {args.train_output_folder}...")
for v in tqdm(train_split):
copy_captioned_image(v, args.source_root, args.train_output_folder)
print("Done.")

validation_default.json (new file, 20 lines)

@@ -0,0 +1,20 @@
{
"documentation": {
"validate_training": "If true, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.",
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
"val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
"stabilize_training_loss": "If true, stabilise the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"every_n_epochs": "How often to run validation (1=every epoch).",
"seed": "The seed to use when running validation and stabilization passes."
},
"validate_training": true,
"val_split_mode": "automatic",
"val_data_root": null,
"val_split_proportion": 0.15,
"stabilize_training_loss": true,
"stabilize_split_proportion": 0.15,
"every_n_epochs": 1,
"seed": 555
}