EveryDream2trainer/data/every_dream_validation.py

224 lines
10 KiB
Python

import json
import logging
import math
import random
from typing import Callable, Any, Optional, Generator
from argparse import Namespace
import torch
import numpy as np
from colorama import Fore, Style
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from data.every_dream import build_torch_dataloader, EveryDreamBatch
from data.data_loader import DataLoaderMultiAspect
from data import resolver
from data import aspects
from data.image_train_item import ImageTrainItem
from utils.isolate_rng import isolate_rng
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
        -> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
    """
    Randomly partition `items` into a split and the remainder.

    The split receives ``max(1, ceil(split_proportion * len(items)))`` items (capped at
    ``len(items)`` by slicing); the remainder receives everything else. Items are sorted
    by ``pathname`` before being shuffled with the module-level ``random`` generator, so
    (assuming pathnames are unique) the result depends only on the current ``random``
    state and not on the incoming order of `items`.

    :param items: items to split; the passed-in list itself is not reordered.
    :param split_proportion: fraction of `items` to place in the split.
    :param batch_size: currently unused; kept for call-site compatibility.
    :return: ``(split_items, remaining_items)`` as new lists.
    """
    split_item_count = max(1, math.ceil(split_proportion * len(items)))
    # sort first, then shuffle, to ensure determinate outcome for the current random state
    items_copy = sorted(items, key=lambda i: i.pathname)
    random.shuffle(items_copy)
    # slicing a list already produces new lists; no extra copy is needed
    return items_copy[:split_item_count], items_copy[split_item_count:]
def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]:
    """
    Lazily yield a new ``ImageTrainItem`` for each entry of `items`.

    Each new item copies ``image``, ``caption``, ``aspects`` and ``pathname`` from its
    source item but is constructed with ``flip_p=0`` and ``multiplier=1`` (presumably
    meaning: no random flipping and no repetition — confirm against ``ImageTrainItem``).
    """
    yield from (
        ImageTrainItem(
            image=source.image,
            caption=source.caption,
            aspects=source.aspects,
            pathname=source.pathname,
            flip_p=0,
            multiplier=1,
        )
        for source in items
    )
class EveryDreamValidator:
    """
    Build validation dataloaders from a JSON config and periodically log their mean loss.

    Two optional dataloaders are managed:

    * ``val``: a held-out set, either split off the training items
      (``val_split_mode: "automatic"``) or loaded from ``val_data_root``
      (``val_split_mode: "manual"``).
    * ``stabilize-train``: a fixed random subset of the (remaining) training items,
      enabled by ``stabilize_training_loss``. These items stay in the training set.

    Each logged loss is offset by the negated first measurement of the same dataloader,
    so every curve starts at 0 and shows change relative to the first validation run.
    """

    def __init__(self,
                 val_config_path: Optional[str],
                 default_batch_size: int,
                 resolution: int,
                 log_writer: SummaryWriter):
        """
        :param val_config_path: optional path to a JSON file whose top-level keys
            override the defaults in ``self.config``.
        :param default_batch_size: batch size used unless the config sets ``batch_size``.
        :param resolution: passed to ``aspects.get_aspect_buckets`` when loading a manual
            validation split.
        :param log_writer: receives the ``loss/val`` and ``loss/stabilize-train`` scalars.
        """
        self.val_dataloader = None
        self.train_overlapping_dataloader = None

        self.log_writer = log_writer
        self.resolution = resolution

        self.config = {
            'batch_size': default_batch_size,
            'every_n_epochs': 1,
            'seed': 555,
            'validate_training': True,
            'val_split_mode': 'automatic',
            'val_split_proportion': 0.15,
            'stabilize_training_loss': False,
            'stabilize_split_proportion': 0.15
        }
        if val_config_path is not None:
            # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding
            with open(val_config_path, 'rt', encoding='utf-8') as f:
                self.config.update(json.load(f))

        # Fixed from the first measurement of each dataloader; see do_validation_if_appropriate
        self.train_overlapping_dataloader_loss_offset = None
        self.val_loss_offset = None

        # Raw (un-offset) mean validation losses, oldest first
        self.loss_val_history = []
        self.val_loss_window_size = 4  # todo: arg for this?

    @property
    def batch_size(self):
        """Batch size for the validation and stabilizer dataloaders (config key ``batch_size``)."""
        return self.config['batch_size']

    @property
    def every_n_epochs(self):
        """Validation runs on epochs that are a multiple of this (config key ``every_n_epochs``)."""
        return self.config['every_n_epochs']

    @property
    def seed(self):
        """Seed used for splitting, data loading and loss calculation (config key ``seed``)."""
        return self.config['seed']

    def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
        """
        Build the validation splits as requested by the config passed at init.
        This may steal some items from `train_items`.
        If this happens, the returned `list` contains the remaining items after the required items have been stolen.
        Otherwise, the returned `list` is identical to the passed-in `train_items`.

        All random choices are made with ``random`` seeded to ``self.seed`` inside
        ``isolate_rng()``.
        """
        with isolate_rng():
            random.seed(self.seed)
            self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
            # order is important - if we're removing images from train, this needs to happen before making
            # the overlapping dataloader
            self.train_overlapping_dataloader = self._build_train_stabilizer_dataloader_if_required(
                remaining_train_items, tokenizer)
            return remaining_train_items

    def do_validation_if_appropriate(self, epoch: int, global_step: int,
                                     get_model_prediction_and_target_callable: Callable[
                                         [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
        """
        Run validation when `epoch` is a multiple of ``every_n_epochs`` and log the results.

        For each configured dataloader, the mean loss is computed and its offset is fixed on
        the first call. ``offset + mean_loss`` is then logged at `global_step`. The raw val
        loss is also appended to ``loss_val_history``. Once enough history exists, a warning
        is logged when the average step-to-step change over the last
        ``val_loss_window_size`` entries is positive.

        :param get_model_prediction_and_target_callable: called with
            ``(batch["image"], batch["tokens"])``; must return ``(model_pred, target)``.
        """
        if (epoch % self.every_n_epochs) != 0:
            return

        if self.train_overlapping_dataloader is not None:
            mean_loss = self._calculate_validation_loss('stabilize-train',
                                                        self.train_overlapping_dataloader,
                                                        get_model_prediction_and_target_callable)
            if self.train_overlapping_dataloader_loss_offset is None:
                self.train_overlapping_dataloader_loss_offset = -mean_loss
            self.log_writer.add_scalar(tag="loss/stabilize-train",
                                       scalar_value=self.train_overlapping_dataloader_loss_offset + mean_loss,
                                       global_step=global_step)

        if self.val_dataloader is not None:
            mean_loss = self._calculate_validation_loss('val',
                                                        self.val_dataloader,
                                                        get_model_prediction_and_target_callable)
            if self.val_loss_offset is None:
                self.val_loss_offset = -mean_loss
            self.log_writer.add_scalar(tag="loss/val",
                                       scalar_value=self.val_loss_offset + mean_loss,
                                       global_step=global_step)
            self.loss_val_history.append(mean_loss)
            # only judge the trend once enough history has accumulated
            if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
                # mean step-to-step change over the most recent window; > 0 means val loss is rising
                dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
                if np.average(dy) > 0:
                    logging.warning("Validation loss shows diverging")
                # todo: signal stop?

    def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
            [Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
        """
        Return the average over batches of the per-batch mean-squared error.

        Runs under ``torch.no_grad()``. ``random`` and ``torch`` are reseeded to
        ``self.seed`` inside ``isolate_rng()``, so repeated calls see the same random draws.
        Restoring the caller's RNG state presumably relies on ``isolate_rng``.

        :param tag: label shown in the progress bar and in the error message.
        :param dataloader: supports ``len()``; yields batches with ``"image"`` and
            ``"tokens"`` entries.
        :param get_model_prediction_and_target: called with ``(batch["image"], batch["tokens"])``;
            must return ``(model_pred, target)``.
        :raises ValueError: if `dataloader` yields no batches, in which case no mean exists.
        """
        loss_validation_epoch = []
        with torch.no_grad(), isolate_rng():
            # ok to override seed here because we are in a `with isolate_rng():` block
            random.seed(self.seed)
            torch.manual_seed(self.seed)

            # the context manager closes the bar even if the model call raises
            with tqdm(range(len(dataloader)), position=1, leave=False) as steps_pbar:
                steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
                for batch in dataloader:
                    model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                    # drop references to the prediction/target tensors before the next batch
                    del target, model_pred
                    loss_validation_epoch.append(loss.detach().item())
                    steps_pbar.update(1)

        if not loss_validation_epoch:
            raise ValueError(f"Validation dataloader '{tag}' produced no batches; cannot compute a mean loss")
        return sum(loss_validation_epoch) / len(loss_validation_epoch)

    def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
            -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
        """
        Build the validation dataloader according to the ``val_split_mode`` config key.

        The modes are:

        * ``None``/``'none'``, or ``validate_training`` falsy: no dataloader is built and the
          train items are returned unchanged.
        * ``'automatic'``: ``val_split_proportion`` of the train items are moved into the
          validation set, re-created with ``flip_p=0`` and ``multiplier=1``.
        * ``'manual'``: validation items are loaded from ``val_data_root``; the train items
          are returned unchanged.

        :return: ``(val_dataloader or None, remaining_train_items)``
        :raises ValueError: for an unrecognized mode, for manual mode without
            ``val_data_root``, or when ``val_data_root`` yields no images.
        """
        val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
        val_split_proportion = self.config['val_split_proportion']
        remaining_train_items = image_train_items
        if val_split_mode is None or val_split_mode == 'none':
            return None, image_train_items
        elif val_split_mode == 'automatic':
            val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
            val_items = list(disable_multiplier_and_flip(val_items))
            logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
        elif val_split_mode == 'manual':
            val_data_root = self.config.get('val_data_root', None)
            if val_data_root is None:
                raise ValueError("Manual validation split requested but `val_data_root` is not defined in validation config")
            val_items = self._load_manual_val_split(val_data_root)
            # fail fast here rather than at the first validation pass, mid-training
            if not val_items:
                raise ValueError(f"Manual validation split requested but no images were found in `val_data_root` '{val_data_root}'")
            logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
        else:
            raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
        val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
        val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
        return val_dataloader, remaining_train_items

    def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
            -> Optional[torch.utils.data.DataLoader]:
        """
        Build a dataloader over a random subset of the training items.

        Returns ``None`` unless ``stabilize_training_loss`` is set. The sampled items are
        NOT removed from training; the remainder returned by ``get_random_split`` is
        discarded. The selected items are re-created with ``flip_p=0`` and ``multiplier=1``.
        """
        stabilize_training_loss = self.config['stabilize_training_loss']
        if not stabilize_training_loss:
            return None

        stabilize_split_proportion = self.config['stabilize_split_proportion']
        stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
        stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
        stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
                                                  name='stabilize-train')
        stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
        return stabilize_dataloader

    def _load_manual_val_split(self, val_data_root: str):
        """
        Resolve the images under `val_data_root` into validation items.

        Items are resolved with ``flip_p=0.0`` and the aspect buckets for
        ``self.resolution``. They are then sorted by ``pathname`` and shuffled with the
        module-level ``random``; ``prepare_validation_splits`` seeds it beforehand.
        """
        args = Namespace(
            aspects=aspects.get_aspect_buckets(self.resolution),
            flip_p=0.0,
            seed=self.seed,
        )
        val_items = resolver.resolve_root(val_data_root, args)
        # sort first so the shuffle result depends only on the random state
        val_items.sort(key=lambda i: i.pathname)
        random.shuffle(val_items)
        return val_items

    def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
        """
        Wrap `items` in an ``EveryDreamBatch`` for loss evaluation.

        The loader and the batch are seeded with ``self.seed``. The batch is built with
        ``conditional_dropout=0``, ``crop_jitter=0`` and ``debug_level=1``.

        :param batch_size: batch size for the underlying ``DataLoaderMultiAspect``. This
            argument used to be silently replaced by ``self.batch_size``; it is now honoured.
        :param name: label forwarded to ``EveryDreamBatch``.
        """
        seed = self.seed
        data_loader = DataLoaderMultiAspect(
            items,
            batch_size=batch_size,
            seed=seed,
        )

        ed_batch = EveryDreamBatch(
            data_loader=data_loader,
            debug_level=1,
            conditional_dropout=0,
            tokenizer=tokenizer,
            seed=seed,
            name=name,
            crop_jitter=0
        )
        return ed_batch