EveryDream2trainer/data/every_dream_validation.py

278 lines
12 KiB
Python

import json
import logging
import math
import random
from dataclasses import dataclass, field
from typing import Callable, Any, Optional, Generator
from argparse import Namespace
import torch
import numpy as np
from colorama import Fore, Style
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from data.every_dream import build_torch_dataloader, EveryDreamBatch
from data.data_loader import DataLoaderMultiAspect
from data import resolver
from data import aspects
from data.image_train_item import ImageTrainItem
from utils.isolate_rng import isolate_rng
from colorama import Fore, Style
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
        -> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
    """
    Randomly partition `items` into a split and the remainder.

    The split receives ceil(split_proportion * len(items)) items, but never fewer than one
    (an empty `items` yields two empty lists). Items are ordered by `pathname` before being
    shuffled with the module-level `random` generator, so the outcome depends only on the
    current random state and not on the order of `items`.

    `batch_size` is accepted for interface compatibility; it is not used.

    Returns:
        (split_items, remaining_items), both newly allocated lists.
    """
    # sort first, then shuffle, to ensure determinate outcome for the current random state
    ordered = sorted(items, key=lambda item: item.pathname)
    random.shuffle(ordered)

    cut = max(1, math.ceil(split_proportion * len(items)))
    return ordered[:cut], ordered[cut:]
def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]:
    """
    Lazily yield a fresh `ImageTrainItem` for each entry of `items`.

    Each new item carries over only the source item's image, caption, aspects and pathname,
    and is constructed with `flip_p=0` and `multiplier=1` regardless of the source values
    (presumably so validation images are never flipped or repeated — see ImageTrainItem).
    """
    yield from (
        ImageTrainItem(
            image=src.image,
            caption=src.caption,
            aspects=src.aspects,
            pathname=src.pathname,
            flip_p=0,
            multiplier=1,
        )
        for src in items
    )
@dataclass
class ValidationDataset:
    """
    A named validation dataloader together with the mean losses measured on it.

    Attributes (dataclass fields):
        name: label used in the divergence warning.
        dataloader: loader that yields the validation batches.
        loss_history: mean losses recorded by `track_loss_trend`, oldest first.
        val_loss_window_size: how many of the most recent losses are inspected for an upward
            trend; `None` disables tracking entirely (nothing is recorded).
    """
    name: str
    dataloader: torch.utils.data.DataLoader
    loss_history: list[float] = field(default_factory=list)
    val_loss_window_size: Optional[int] = 5  # todo: arg for this?

    def track_loss_trend(self, mean_loss: float):
        """
        Record `mean_loss` and log a warning if the recent losses are trending upward.

        Does nothing when `val_loss_window_size` is None. Once more than
        `2 * val_loss_window_size + 1` losses have been recorded, the average step-to-step
        change over the last `val_loss_window_size` losses is computed; a positive average logs
        a warning. Window sizes below 2 contain no step-to-step change, so for them the loss is
        recorded but no trend check is made.
        """
        window = self.val_loss_window_size
        if window is None:
            return

        self.loss_history.append(mean_loss)

        # A slope needs at least two points. Without this guard a window of 0 would slice
        # `[-0:]` (the WHOLE history) and a window of 1 would average an empty diff (NaN).
        if window < 2:
            return

        # only judge the trend once enough history has accumulated
        if len(self.loss_history) > (window * 2) + 1:
            recent_changes = np.diff(self.loss_history[-window:])
            if np.average(recent_changes) > 0:
                logging.warning(f"Validation loss for {self.name} shows diverging. Check your loss/{self.name} graph.")
class EveryDreamValidator:
    """
    Builds validation dataloaders from a JSON config and logs their mean losses to TensorBoard.

    `prepare_validation_splits` can create these datasets, in this order:
      * 'val' in 'automatic' split mode: a random slice removed from the training items;
      * 'stabilize-train' when 'stabilize_training_loss' is set: a random slice of the
        (remaining) training items that is NOT removed from training;
      * one dataset per entry of 'extra_manual_datasets' (which also holds 'val' in 'manual'
        split mode), loaded from disk.
    """

    def __init__(self,
                 val_config_path: Optional[str],
                 default_batch_size: int,
                 resolution: int,
                 log_writer: SummaryWriter,
                 ):
        """
        Args:
            val_config_path: path to a JSON file whose keys override the defaults below,
                or None to use the defaults only.
            default_batch_size: batch size used unless the config sets 'batch_size'.
            resolution: passed to `aspects.get_aspect_buckets` when loading manual datasets.
            log_writer: receives one 'loss/<dataset name>' scalar per dataset per validation.

        Raises:
            ValueError: if 'manual' split mode is configured without 'manual_data_root'
                (or its legacy name 'val_data_root').
        """
        self.validation_datasets = []
        self.resolution = resolution
        self.log_writer = log_writer

        self.config = {
            'batch_size': default_batch_size,
            'every_n_epochs': 1,
            'seed': 555,

            'validate_training': True,
            'val_split_mode': 'automatic',
            'auto_split_proportion': 0.15,

            'stabilize_training_loss': False,
            'stabilize_split_proportion': 0.15,

            'use_relative_loss': False,

            'extra_manual_datasets': {
                # name: path pairs
                # eg "santa suit": "/path/to/captioned_santa_suit_images", will be logged to tensorboard as "loss/santa suit"
            }
        }
        if val_config_path is not None:
            # JSON is UTF-8 by specification; don't depend on the platform locale encoding.
            # 'utf-8-sig' additionally tolerates a leading BOM, which json.load would reject.
            with open(val_config_path, 'rt', encoding='utf-8-sig') as f:
                self.config.update(json.load(f))

        if 'val_data_root' in self.config:
            logging.warning(f" * {Fore.YELLOW}using old name 'val_data_root' for 'manual_data_root' - please "
                            f"update your validation config json{Style.RESET_ALL}")
            self.config.update({'manual_data_root': self.config['val_data_root']})

        if self.config.get('val_split_mode') == 'manual':
            if 'manual_data_root' not in self.config:
                raise ValueError("Error in validation config .json: 'manual' validation is missing 'manual_data_root'")
            # manual mode is served by the generic manual-dataset path under the name 'val'
            self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})

        if 'val_split_proportion' in self.config:
            logging.warning(f" * {Fore.YELLOW}using old name 'val_split_proportion' for 'auto_split_proportion' - please "
                            f"update your validation config json{Style.RESET_ALL}")
            self.config.update({'auto_split_proportion': self.config['val_split_proportion']})

    @property
    def batch_size(self):
        """Batch size used for every validation dataloader."""
        return self.config['batch_size']

    @property
    def every_n_epochs(self):
        """Validation frequency in epochs; values below 1 mean several validations per epoch."""
        return self.config['every_n_epochs']

    @property
    def seed(self):
        """Seed used for split selection and for the RNG state during loss calculation."""
        return self.config['seed']

    @property
    def use_relative_loss(self):
        """The 'use_relative_loss' config flag (not read anywhere in this class)."""
        return self.config['use_relative_loss']

    def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
        """
        Build the validation splits as requested by the config passed at init.
        This may steal some items from `train_items`.
        If this happens, the returned `list` contains the remaining items after the required items have been stolen.
        Otherwise, the returned `list` is identical to the passed-in `train_items`.

        The built datasets are appended to `self.validation_datasets`. All random choices are
        made with the Python `random` module seeded from `self.seed` inside `isolate_rng()`.
        """
        with isolate_rng():
            random.seed(self.seed)
            auto_dataset, remaining_train_items = self._build_automatic_validation_dataset_if_required(train_items, tokenizer)
            # order is important - if we're removing images from train, this needs to happen before making
            # the overlapping dataloader
            train_overlapping_dataset = self._build_train_stabilizer_dataloader_if_required(
                remaining_train_items, tokenizer)

            if auto_dataset is not None:
                self.validation_datasets.append(auto_dataset)
            if train_overlapping_dataset is not None:
                self.validation_datasets.append(train_overlapping_dataset)
            manual_splits = self._build_manual_validation_datasets(tokenizer)
            self.validation_datasets.extend(manual_splits)

            return remaining_train_items

    def get_validation_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
        """
        Return the 0-based step indices within `epoch` after which validation should run.

        every_n_epochs >= 1: the last step of every `every_n_epochs`-th epoch (epochs counted
        from 1), and no steps for the other epochs.
        every_n_epochs < 1: the epoch is divided into round(1 / every_n_epochs) roughly equal
        parts (at least 1, at most one per step); validation runs after the last step of each.

        NOTE(review): every_n_epochs == 0 raises ZeroDivisionError here.
        """
        if self.every_n_epochs >= 1:
            if ((epoch + 1) % self.every_n_epochs) == 0:
                # last step only
                return [epoch_length_steps - 1]
            else:
                return []
        else:
            # subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
            num_divisions = max(1, min(epoch_length_steps, round(1 / self.every_n_epochs)))
            # validation happens after training:
            # if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
            validate_every_n_steps = epoch_length_steps / num_divisions
            return [math.ceil((i + 1) * validate_every_n_steps) - 1 for i in range(num_divisions)]

    def do_validation(self, global_step: int,
                      get_model_prediction_and_target_callable: Callable[
                          [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
        """
        Compute the mean loss of every prepared validation dataset.

        Each loss is written to `log_writer` as 'loss/<dataset name>' at `global_step` and
        passed to the dataset's `track_loss_trend`.
        """
        for dataset in self.validation_datasets:
            mean_loss = self._calculate_validation_loss(dataset.name,
                                                        dataset.dataloader,
                                                        get_model_prediction_and_target_callable)
            self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
                                       scalar_value=mean_loss,
                                       global_step=global_step)
            dataset.track_loss_trend(mean_loss)

    def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
            [Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
        """
        Return the unweighted mean over batches of the per-batch MSE between prediction and target.

        `get_model_prediction_and_target` is called with each batch's "image" and "tokens"
        entries. Runs under `torch.no_grad()`; the Python and torch RNGs are reseeded from
        `self.seed` inside `isolate_rng()`, which is relied on to restore the previous RNG state.
        If the dataloader yields no batches, a warning is logged and NaN is returned.
        """
        batch_losses = []
        with torch.no_grad(), isolate_rng():
            # ok to override seed here because we are in a `with isolate_rng():` block
            random.seed(self.seed)
            torch.manual_seed(self.seed)

            # `with` closes the nested progress bar even if the model call raises
            with tqdm(range(len(dataloader)), position=1, leave=False) as steps_pbar:
                steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
                for batch in dataloader:
                    model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                    # drop references to the large tensors before the next batch
                    del target, model_pred
                    batch_losses.append(loss.detach().item())
                    steps_pbar.update(1)

        if not batch_losses:
            # previously a ZeroDivisionError that aborted the run
            logging.warning(f" * {Fore.YELLOW}Validation set '{tag}' produced no batches; "
                            f"logging NaN as its loss{Style.RESET_ALL}")
            return math.nan
        return sum(batch_losses) / len(batch_losses)

    def _build_automatic_validation_dataset_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
            -> tuple[Optional[ValidationDataset], list[ImageTrainItem]]:
        """
        In 'automatic' split mode, move a random 'auto_split_proportion' of the training items
        into a dataset named 'val'.

        Returns (dataset, remaining training items). The dataset is None, and the items are
        returned unchanged, when 'validate_training' is falsy or the mode is 'none'/'manual'.

        Raises:
            ValueError: for any other split mode.
        """
        val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
        if val_split_mode is None or val_split_mode == 'none' or val_split_mode == 'manual':
            # manual is handled by _build_manual_validation_datasets
            return None, image_train_items
        elif val_split_mode == 'automatic':
            auto_split_proportion = self.config['auto_split_proportion']
            val_items, remaining_train_items = get_random_split(image_train_items, auto_split_proportion, batch_size=self.batch_size)
            val_items = list(disable_multiplier_and_flip(val_items))
            logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
            val_ed_batch = self._build_ed_batch(val_items, tokenizer=tokenizer, name='val')
            val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
            return ValidationDataset(name='val', dataloader=val_dataloader), remaining_train_items
        else:
            raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")

    def _build_manual_validation_datasets(self, tokenizer) -> list[ValidationDataset]:
        """Build one dataset per (name, root folder) entry of 'extra_manual_datasets'."""
        datasets = []
        for name, root in self.config.get('extra_manual_datasets', {}).items():
            items = self._load_manual_val_split(root)
            logging.info(f" * Loaded {len(items)} validation images for validation set '{name}' from {root}")
            ed_batch = self._build_ed_batch(items, tokenizer=tokenizer, name=name)
            dataloader = build_torch_dataloader(ed_batch, batch_size=self.batch_size)
            datasets.append(ValidationDataset(name=name, dataloader=dataloader))
        return datasets

    def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
            -> Optional[ValidationDataset]:
        """
        When 'stabilize_training_loss' is set, build a 'stabilize-train' dataset from a random
        'stabilize_split_proportion' of `image_train_items`; otherwise return None.

        The chosen items are not removed from training, and loss-trend tracking is disabled
        for this dataset (`val_loss_window_size=None`).
        """
        stabilize_training_loss = self.config['stabilize_training_loss']
        if not stabilize_training_loss:
            return None

        stabilize_split_proportion = self.config['stabilize_split_proportion']
        stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
        stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
        stabilize_ed_batch = self._build_ed_batch(stabilize_items, tokenizer=tokenizer, name='stabilize-train')
        stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
        return ValidationDataset(name='stabilize-train', dataloader=stabilize_dataloader, val_loss_window_size=None)

    def _load_manual_val_split(self, val_data_root: str):
        """
        Resolve the image items under `val_data_root` (with flip_p=0.0 and this validator's
        aspect buckets and seed), sort them by pathname, then shuffle them with the module-level
        `random` generator.
        """
        args = Namespace(
            aspects=aspects.get_aspect_buckets(self.resolution),
            flip_p=0.0,
            seed=self.seed,
        )
        val_items = resolver.resolve_root(val_data_root, args)
        # sort before shuffling so the order depends only on the random state
        val_items.sort(key=lambda i: i.pathname)
        random.shuffle(val_items)
        return val_items

    def _build_ed_batch(self, items: list[ImageTrainItem], tokenizer, name='val'):
        """
        Wrap `items` in a DataLoaderMultiAspect and an EveryDreamBatch using this validator's
        batch size and seed, with conditional_dropout=0 and crop_jitter=0.
        """
        batch_size = self.batch_size
        seed = self.seed
        data_loader = DataLoaderMultiAspect(
            items,
            batch_size=batch_size,
            seed=seed,
        )
        ed_batch = EveryDreamBatch(
            data_loader=data_loader,
            debug_level=1,
            conditional_dropout=0,
            tokenizer=tokenizer,
            seed=seed,
            name=name,
            crop_jitter=0
        )
        return ed_batch