278 lines
12 KiB
Python
278 lines
12 KiB
Python
import json
|
|
import logging
|
|
import math
|
|
import random
|
|
from dataclasses import dataclass, field
|
|
from typing import Callable, Any, Optional, Generator
|
|
from argparse import Namespace
|
|
|
|
import torch
|
|
import numpy as np
|
|
from colorama import Fore, Style
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from tqdm.auto import tqdm
|
|
|
|
from data.every_dream import build_torch_dataloader, EveryDreamBatch
|
|
from data.data_loader import DataLoaderMultiAspect
|
|
from data import resolver
|
|
from data import aspects
|
|
from data.image_train_item import ImageTrainItem
|
|
from utils.isolate_rng import isolate_rng
|
|
|
|
from colorama import Fore, Style
|
|
|
|
|
|
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
|
|
-> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
|
|
split_item_count = max(1, math.ceil(split_proportion * len(items)))
|
|
# sort first, then shuffle, to ensure determinate outcome for the current random state
|
|
items_copy = list(sorted(items, key=lambda i: i.pathname))
|
|
random.shuffle(items_copy)
|
|
split_items = list(items_copy[:split_item_count])
|
|
remaining_items = list(items_copy[split_item_count:])
|
|
return split_items, remaining_items
|
|
|
|
def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]:
|
|
for i in items:
|
|
yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1)
|
|
|
|
|
|
@dataclass
|
|
class ValidationDataset:
|
|
name: str
|
|
dataloader: torch.utils.data.DataLoader
|
|
loss_history: list[float] = field(default_factory=list)
|
|
val_loss_window_size: Optional[int] = 5 # todo: arg for this?
|
|
|
|
def track_loss_trend(self, mean_loss: float):
|
|
if self.val_loss_window_size is None:
|
|
return
|
|
self.loss_history.append(mean_loss)
|
|
|
|
if len(self.loss_history) > ((self.val_loss_window_size * 2) + 1):
|
|
dy = np.diff(self.loss_history[-self.val_loss_window_size:])
|
|
if np.average(dy) > 0:
|
|
logging.warning(f"Validation loss for {self.name} shows diverging. Check your loss/{self.name} graph.")
|
|
|
|
|
|
class EveryDreamValidator:
|
|
def __init__(self,
|
|
val_config_path: Optional[str],
|
|
default_batch_size: int,
|
|
resolution: int,
|
|
log_writer: SummaryWriter,
|
|
):
|
|
self.validation_datasets = []
|
|
self.resolution = resolution
|
|
self.log_writer = log_writer
|
|
|
|
self.config = {
|
|
'batch_size': default_batch_size,
|
|
'every_n_epochs': 1,
|
|
'seed': 555,
|
|
|
|
'validate_training': True,
|
|
'val_split_mode': 'automatic',
|
|
'auto_split_proportion': 0.15,
|
|
|
|
'stabilize_training_loss': False,
|
|
'stabilize_split_proportion': 0.15,
|
|
|
|
'use_relative_loss': False,
|
|
|
|
'extra_manual_datasets': {
|
|
# name: path pairs
|
|
# eg "santa suit": "/path/to/captioned_santa_suit_images", will be logged to tensorboard as "loss/santa suit"
|
|
}
|
|
}
|
|
if val_config_path is not None:
|
|
with open(val_config_path, 'rt') as f:
|
|
self.config.update(json.load(f))
|
|
|
|
if 'val_data_root' in self.config:
|
|
logging.warning(f" * {Fore.YELLOW}using old name 'val_data_root' for 'manual_data_root' - please "
|
|
f"update your validation config json{Style.RESET_ALL}")
|
|
self.config.update({'manual_data_root': self.config['val_data_root']})
|
|
|
|
if self.config.get('val_split_mode') == 'manual':
|
|
if 'manual_data_root' not in self.config:
|
|
raise ValueError("Error in validation config .json: 'manual' validation is missing 'manual_data_root'")
|
|
self.config['extra_manual_datasets'].update({'val': self.config['manual_data_root']})
|
|
|
|
if 'val_split_proportion' in self.config:
|
|
logging.warning(f" * {Fore.YELLOW}using old name 'val_split_proportion' for 'auto_split_proportion' - please "
|
|
f"update your validation config json{Style.RESET_ALL}")
|
|
self.config.update({'auto_split_proportion': self.config['val_split_proportion']})
|
|
|
|
|
|
|
|
@property
|
|
def batch_size(self):
|
|
return self.config['batch_size']
|
|
|
|
@property
|
|
def every_n_epochs(self):
|
|
return self.config['every_n_epochs']
|
|
|
|
@property
|
|
def seed(self):
|
|
return self.config['seed']
|
|
|
|
@property
|
|
def use_relative_loss(self):
|
|
return self.config['use_relative_loss']
|
|
|
|
def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
|
|
"""
|
|
Build the validation splits as requested by the config passed at init.
|
|
This may steal some items from `train_items`.
|
|
If this happens, the returned `list` contains the remaining items after the required items have been stolen.
|
|
Otherwise, the returned `list` is identical to the passed-in `train_items`.
|
|
"""
|
|
with isolate_rng():
|
|
random.seed(self.seed)
|
|
|
|
auto_dataset, remaining_train_items = self._build_automatic_validation_dataset_if_required(train_items, tokenizer)
|
|
# order is important - if we're removing images from train, this needs to happen before making
|
|
# the overlapping dataloader
|
|
train_overlapping_dataset = self._build_train_stabilizer_dataloader_if_required(
|
|
remaining_train_items, tokenizer)
|
|
|
|
if auto_dataset is not None:
|
|
self.validation_datasets.append(auto_dataset)
|
|
if train_overlapping_dataset is not None:
|
|
self.validation_datasets.append(train_overlapping_dataset)
|
|
manual_splits = self._build_manual_validation_datasets(tokenizer)
|
|
self.validation_datasets.extend(manual_splits)
|
|
|
|
return remaining_train_items
|
|
|
|
def get_validation_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
|
|
if self.every_n_epochs >= 1:
|
|
if ((epoch+1) % self.every_n_epochs) == 0:
|
|
# last step only
|
|
return [epoch_length_steps-1]
|
|
else:
|
|
return []
|
|
else:
|
|
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
|
|
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
|
|
# validation happens after training:
|
|
# if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
|
|
validate_every_n_steps = epoch_length_steps / num_divisions
|
|
return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
|
|
|
|
def do_validation(self, global_step: int,
|
|
get_model_prediction_and_target_callable: Callable[
|
|
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
|
for i, dataset in enumerate(self.validation_datasets):
|
|
mean_loss = self._calculate_validation_loss(dataset.name,
|
|
dataset.dataloader,
|
|
get_model_prediction_and_target_callable)
|
|
self.log_writer.add_scalar(tag=f"loss/{dataset.name}",
|
|
scalar_value=mean_loss,
|
|
global_step=global_step)
|
|
dataset.track_loss_trend(mean_loss)
|
|
|
|
|
|
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
|
|
[Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
|
|
with torch.no_grad(), isolate_rng():
|
|
# ok to override seed here because we are in a `with isolate_rng():` block
|
|
random.seed(self.seed)
|
|
torch.manual_seed(self.seed)
|
|
|
|
loss_validation_epoch = []
|
|
steps_pbar = tqdm(range(len(dataloader)), position=1, leave=False)
|
|
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
|
|
|
|
for step, batch in enumerate(dataloader):
|
|
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
|
|
|
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
|
|
|
del target, model_pred
|
|
|
|
loss_step = loss.detach().item()
|
|
loss_validation_epoch.append(loss_step)
|
|
|
|
steps_pbar.update(1)
|
|
|
|
steps_pbar.close()
|
|
|
|
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
|
|
return loss_validation_local
|
|
|
|
|
|
def _build_automatic_validation_dataset_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
|
|
-> tuple[Optional[ValidationDataset], list[ImageTrainItem]]:
|
|
val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
|
|
if val_split_mode is None or val_split_mode == 'none' or val_split_mode == 'manual':
|
|
# manual is handled by _build_manual_validation_datasets
|
|
return None, image_train_items
|
|
elif val_split_mode == 'automatic':
|
|
auto_split_proportion = self.config['auto_split_proportion']
|
|
val_items, remaining_train_items = get_random_split(image_train_items, auto_split_proportion, batch_size=self.batch_size)
|
|
val_items = list(disable_multiplier_and_flip(val_items))
|
|
logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
|
|
val_ed_batch = self._build_ed_batch(val_items, tokenizer=tokenizer, name='val')
|
|
val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
|
|
return ValidationDataset(name='val', dataloader=val_dataloader), remaining_train_items
|
|
else:
|
|
raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
|
|
|
|
def _build_manual_validation_datasets(self, tokenizer) -> list[ValidationDataset]:
|
|
datasets = []
|
|
for name, root in self.config.get('extra_manual_datasets', {}).items():
|
|
items = self._load_manual_val_split(root)
|
|
logging.info(f" * Loaded {len(items)} validation images for validation set '{name}' from {root}")
|
|
ed_batch = self._build_ed_batch(items, tokenizer=tokenizer, name=name)
|
|
dataloader = build_torch_dataloader(ed_batch, batch_size=self.batch_size)
|
|
datasets.append(ValidationDataset(name=name, dataloader=dataloader))
|
|
return datasets
|
|
|
|
def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
|
|
-> Optional[ValidationDataset]:
|
|
stabilize_training_loss = self.config['stabilize_training_loss']
|
|
if not stabilize_training_loss:
|
|
return None
|
|
|
|
stabilize_split_proportion = self.config['stabilize_split_proportion']
|
|
stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
|
|
stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
|
|
stabilize_ed_batch = self._build_ed_batch(stabilize_items, tokenizer=tokenizer, name='stabilize-train')
|
|
stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
|
|
return ValidationDataset(name='stabilize-train', dataloader=stabilize_dataloader, val_loss_window_size=None)
|
|
|
|
def _load_manual_val_split(self, val_data_root: str):
|
|
args = Namespace(
|
|
aspects=aspects.get_aspect_buckets(self.resolution),
|
|
flip_p=0.0,
|
|
seed=self.seed,
|
|
)
|
|
val_items = resolver.resolve_root(val_data_root, args)
|
|
val_items.sort(key=lambda i: i.pathname)
|
|
random.shuffle(val_items)
|
|
return val_items
|
|
|
|
def _build_ed_batch(self, items: list[ImageTrainItem], tokenizer, name='val'):
|
|
batch_size = self.batch_size
|
|
seed = self.seed
|
|
data_loader = DataLoaderMultiAspect(
|
|
items,
|
|
batch_size=batch_size,
|
|
seed=seed,
|
|
)
|
|
ed_batch = EveryDreamBatch(
|
|
data_loader=data_loader,
|
|
debug_level=1,
|
|
conditional_dropout=0,
|
|
tokenizer=tokenizer,
|
|
seed=seed,
|
|
name=name,
|
|
crop_jitter=0
|
|
)
|
|
return ed_batch
|