# EveryDream2trainer/data/every_dream_validation.py
#
# 174 lines
# 7.6 KiB
# Python

import json
import math
import random
from typing import Callable, Any, Optional
from argparse import Namespace
import torch
from colorama import Fore, Style
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from data.every_dream import build_torch_dataloader, EveryDreamBatch
from data.data_loader import DataLoaderMultiAspect
from data import resolver
from data import aspects
from data.image_train_item import ImageTrainItem
from utils.isolate_rng import isolate_rng
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
        -> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
    """
    Randomly partition `items` into a split of roughly `split_proportion` of the items, plus the remainder.

    The split size is `split_proportion * len(items)` rounded UP to a whole number of batches of
    `batch_size`. It can therefore exceed the requested proportion and, for small inputs, can take
    every item (leaving the remainder empty).

    `items` itself is not modified. Items are sorted by `pathname` before being shuffled with the global
    `random` module, so (given unique pathnames) the outcome depends only on the set of items and the
    current random state, not on the order in which `items` was passed.

    Returns `(split_items, remaining_items)`.
    """
    # True division is required here: with `//` the quotient is floored before `math.ceil` sees it,
    # which silently rounds down and can yield an empty split for small datasets.
    split_batch_count = math.ceil(split_proportion * len(items) / batch_size)
    split_item_count = split_batch_count * batch_size
    # sort first, then shuffle, to ensure determinate outcome for the current random state
    items_copy = sorted(items, key=lambda i: i.pathname)
    random.shuffle(items_copy)
    # slicing past the end is safe: an oversized count simply takes everything
    return items_copy[:split_item_count], items_copy[split_item_count:]
class EveryDreamValidator:
    """
    Builds validation dataloaders from the training items and periodically logs validation loss.

    Settings live in `self.config`, which starts from the defaults set in `__init__` and is then
    overridden by the JSON config file, if one is given. Keys read by this class:

    * `batch_size`: batch size used for every dataloader built here.
    * `every_n_epochs`: validation runs when `epoch % every_n_epochs == 0`.
    * `seed`: seed passed to the data loaders and used as the base for per-step torch seeds.
    * `val_split_mode`: one of `'automatic'`, `'manual'` or `'none'`.
    * `val_split_proportion`: proportion of train items taken for validation in `'automatic'` mode.
    * `val_data_root`: root to resolve validation items from in `'manual'` mode (no default).
    * `stabilize_training_loss`: if true, also track loss on a fixed subset of the training items.
    * `stabilize_split_proportion`: proportion of train items used for that fixed subset.
    """

    def __init__(self,
                 val_config_path: Optional[str],
                 default_batch_size: int,
                 log_writer: SummaryWriter):
        """
        Set up the default config and apply overrides from `val_config_path`, if it is not `None`.

        `default_batch_size` is used as `batch_size` unless the config file overrides it.
        `log_writer` receives the loss scalars written during validation.
        """
        self.val_dataloader = None
        self.train_overlapping_dataloader = None
        self.log_writer = log_writer

        self.config = {
            'batch_size': default_batch_size,
            'every_n_epochs': 1,
            'seed': 555,

            'val_split_mode': 'automatic',
            'val_split_proportion': 0.15,

            'stabilize_training_loss': False,
            'stabilize_split_proportion': 0.15
        }
        if val_config_path is not None:
            # JSON is UTF-8 by specification; don't depend on the platform's locale encoding
            with open(val_config_path, 'rt', encoding='utf-8') as f:
                self.config.update(json.load(f))

    @property
    def batch_size(self) -> int:
        """Batch size used for the dataloaders built by this validator."""
        return self.config['batch_size']

    @property
    def every_n_epochs(self) -> int:
        """Validation runs on epochs that are a multiple of this value."""
        return self.config['every_n_epochs']

    @property
    def seed(self) -> int:
        """Seed passed to the data loaders and used as the base for per-step torch seeds."""
        return self.config['seed']

    def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
        """
        Build the validation splits as requested by the config passed at init.
        This may steal some items from `train_items`.
        If this happens, the returned `list` contains the remaining items after the required items have been stolen.
        Otherwise, the returned `list` is identical to the passed-in `train_items`.
        """
        with isolate_rng():
            self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
            # order is important - if we're removing images from train, this needs to happen before making
            # the overlapping dataloader
            self.train_overlapping_dataloader = self._build_train_stabilizer_dataloader_if_required(
                remaining_train_items, tokenizer)
            return remaining_train_items

    def do_validation_if_appropriate(self, epoch: int, global_step: int,
                                     get_model_prediction_and_target_callable: Callable[
                                         [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
        """
        Run validation if `epoch` is a multiple of `every_n_epochs`.

        Each dataloader that was built by `prepare_validation_splits` is evaluated: the training subset
        is logged under the tag `stabilize-train` and the validation set under `val`.
        `get_model_prediction_and_target_callable` is called with a batch's images and tokens and must
        return `(model_pred, target)`.
        """
        if (epoch % self.every_n_epochs) == 0:
            if self.train_overlapping_dataloader is not None:
                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
                                    get_model_prediction_and_target_callable)
            if self.val_dataloader is not None:
                self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)

    def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
            [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
        """
        Run one pass over `dataloader` and log the mean per-batch MSE loss as the scalar `loss/{tag}`.

        Nothing is logged if the dataloader yields no batches.
        """
        with torch.no_grad(), isolate_rng():
            loss_validation_epoch = []
            steps_pbar = tqdm(range(len(dataloader)), position=1)
            steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")

            try:
                for step, batch in enumerate(dataloader):
                    # ok to override seed here because we are in a `with isolate_rng():` block
                    torch.manual_seed(self.seed + step)
                    model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])

                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    # drop the references to the tensors as soon as the loss has been computed
                    del target, model_pred

                    loss_step = loss.detach().item()
                    loss_validation_epoch.append(loss_step)

                    steps_pbar.update(1)
            finally:
                # close the bar even if the prediction callable raises
                steps_pbar.close()

            if not loss_validation_epoch:
                # an empty dataloader has no mean loss; skip logging instead of dividing by zero
                return

            loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
            self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)

    def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
            -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
        """
        Build the validation dataloader according to `val_split_mode`.

        * `'none'`: no dataloader is built; the train items are returned untouched.
        * `'automatic'`: a random split of `val_split_proportion` is taken out of the train items.
        * `'manual'`: validation items are resolved from `val_data_root`; the train items are untouched.

        Returns `(val_dataloader_or_None, remaining_train_items)`.
        Raises `ValueError` for any other mode.
        """
        val_split_mode = self.config['val_split_mode']
        val_split_proportion = self.config['val_split_proportion']
        remaining_train_items = image_train_items
        if val_split_mode == 'none':
            return None, image_train_items
        elif val_split_mode == 'automatic':
            val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
        elif val_split_mode == 'manual':
            # minimal stand-in for the command line args that `resolver.resolve_root` takes
            args = Namespace(
                aspects=aspects.get_aspect_buckets(512),
                flip_p=0.0,
                seed=self.seed,
            )
            val_data_root = self.config['val_data_root']
            val_items = resolver.resolve_root(val_data_root, args)
        else:
            raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
        val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
        val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
        return val_dataloader, remaining_train_items

    def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
            -> Optional[torch.utils.data.DataLoader]:
        """
        Build a dataloader over a random subset of the train items, if `stabilize_training_loss` is set.

        The subset is only copied out for evaluation: the discarded half of the split is ignored, so
        `image_train_items` keeps every item. Returns `None` when the option is disabled.
        """
        stabilize_training_loss = self.config['stabilize_training_loss']
        if not stabilize_training_loss:
            return None

        stabilize_split_proportion = self.config['stabilize_split_proportion']
        stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
        stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
                                                  name='stabilize-train')
        stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
        return stabilize_dataloader

    def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
        """
        Wrap `items` in a `DataLoaderMultiAspect` and an `EveryDreamBatch` seeded with `self.seed`.

        `batch_size` is honoured as passed; every caller in this class passes `self.batch_size`.
        `name` is forwarded to `EveryDreamBatch`.
        """
        seed = self.seed
        data_loader = DataLoaderMultiAspect(
            items,
            batch_size=batch_size,
            seed=seed,
        )
        ed_batch = EveryDreamBatch(
            data_loader=data_loader,
            debug_level=1,
            conditional_dropout=0,
            tokenizer=tokenizer,
            seed=seed,
            name=name,
        )
        return ed_batch