2023-02-08 03:28:45 -07:00
|
|
|
import copy
|
2023-02-06 23:10:34 -07:00
|
|
|
import json
|
2023-02-08 03:28:45 -07:00
|
|
|
import logging
|
2023-02-07 09:32:54 -07:00
|
|
|
import math
|
2023-02-06 23:10:34 -07:00
|
|
|
import random
|
2023-02-08 03:28:45 -07:00
|
|
|
from typing import Callable, Any, Optional, Generator
|
2023-02-06 23:10:34 -07:00
|
|
|
from argparse import Namespace
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from colorama import Fore, Style
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
|
|
|
from data.every_dream import build_torch_dataloader, EveryDreamBatch
|
|
|
|
from data.data_loader import DataLoaderMultiAspect
|
|
|
|
from data import resolver
|
|
|
|
from data import aspects
|
2023-02-07 09:32:54 -07:00
|
|
|
from data.image_train_item import ImageTrainItem
|
2023-02-06 23:10:34 -07:00
|
|
|
from utils.isolate_rng import isolate_rng
|
|
|
|
|
|
|
|
|
2023-02-07 09:32:54 -07:00
|
|
|
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
        -> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
    """
    Randomly split `items` into two lists, reproducibly for the current `random` state.

    The first list gets `split_proportion * len(items)` items, rounded *up* to a whole multiple
    of `batch_size` and capped at `len(items)`. For very small inputs this can place every item
    in the first list.

    :param items: items to split; the input list itself is not reordered or modified.
    :param split_proportion: fraction of `items` to place in the first list.
    :param batch_size: the first list's size is rounded up to a multiple of this.
    :return: `(split_items, remaining_items)`; together they contain every input item exactly once.
    """
    # True division (not `//`) before rounding up. Flooring first would make `math.ceil` a no-op
    # and round a partial batch down to zero items.
    split_batch_count = math.ceil(split_proportion * len(items) / batch_size)
    split_item_count = min(split_batch_count * batch_size, len(items))
    # sort first, then shuffle, to ensure determinate outcome for the current random state
    items_copy = sorted(items, key=lambda i: i.pathname)
    random.shuffle(items_copy)
    return items_copy[:split_item_count], items_copy[split_item_count:]
|
|
|
|
|
2023-02-08 03:28:45 -07:00
|
|
|
def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]:
    """
    Lazily produce one new `ImageTrainItem` per input item.

    Each new item reuses the source item's image, caption, aspects and pathname, but has
    `flip_p=0` and `multiplier=1`. New objects are constructed rather than assigning to the
    inputs' attributes.
    """
    yield from (
        ImageTrainItem(
            image=source.image,
            caption=source.caption,
            aspects=source.aspects,
            pathname=source.pathname,
            flip_p=0,
            multiplier=1,
        )
        for source in items
    )
|
2023-02-07 09:32:54 -07:00
|
|
|
|
2023-02-06 23:10:34 -07:00
|
|
|
class EveryDreamValidator:
    """
    Build optional validation dataloaders from a JSON config and log their mean MSE loss.

    `prepare_validation_splits` may create two dataloaders:
      * `val_dataloader`: items either split off from the training items (mode 'automatic';
        those items are removed from training) or loaded from `val_data_root` (mode 'manual').
      * `train_overlapping_dataloader`: only when 'stabilize_training_loss' is true. It holds a
        random subset of the remaining training items, which stay in the training set.

    `do_validation_if_appropriate` runs whichever dataloaders exist every `every_n_epochs` epochs.
    """

    def __init__(self,
                 val_config_path: Optional[str],
                 default_batch_size: int,
                 resolution: int,
                 log_writer: SummaryWriter):
        """
        :param val_config_path: optional path to a JSON file whose keys override the defaults below.
        :param default_batch_size: batch size used unless the config sets 'batch_size'.
        :param resolution: passed to `aspects.get_aspect_buckets` when loading a manual validation set.
        :param log_writer: receives one `loss/<tag>` scalar per validation pass.
        """
        self.val_dataloader = None
        self.train_overlapping_dataloader = None

        self.log_writer = log_writer
        self.resolution = resolution

        self.config = {
            'batch_size': default_batch_size,
            'every_n_epochs': 1,
            'seed': 555,

            'validate_training': True,
            'val_split_mode': 'automatic',
            'val_split_proportion': 0.15,

            'stabilize_training_loss': False,
            'stabilize_split_proportion': 0.15
        }
        if val_config_path is not None:
            with open(val_config_path, 'rt') as f:
                # keys from the file override the defaults above; extra keys (e.g. 'val_data_root') are kept
                self.config.update(json.load(f))

    @property
    def batch_size(self):
        """Batch size used for every validation dataloader."""
        return self.config['batch_size']

    @property
    def every_n_epochs(self):
        """Validation runs on epochs where `epoch % every_n_epochs == 0`."""
        return self.config['every_n_epochs']

    @property
    def seed(self):
        """Base seed for the validation loaders and for per-step torch reseeding."""
        return self.config['seed']

    def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
        """
        Build the validation splits as requested by the config passed at init.
        This may steal some items from `train_items`.
        If this happens, the returned `list` contains the remaining items after the required items have been stolen.
        Otherwise, the returned `list` is identical to the passed-in `train_items`.
        """
        # isolate_rng so the random splitting here does not disturb the caller's RNG state
        with isolate_rng():
            self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
            # order is important - if we're removing images from train, this needs to happen before making
            # the overlapping dataloader
            self.train_overlapping_dataloader = self._build_train_stabilizer_dataloader_if_required(
                remaining_train_items, tokenizer)
            return remaining_train_items

    def do_validation_if_appropriate(self, epoch: int, global_step: int,
                                     get_model_prediction_and_target_callable: Callable[
                                         [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
        """
        Run every configured validation dataloader if `epoch` is a multiple of `every_n_epochs`.

        :param epoch: current epoch, used only for the `every_n_epochs` check.
        :param global_step: step at which the losses are logged.
        :param get_model_prediction_and_target_callable: called as `f(batch["image"], batch["tokens"])`;
            must return `(model_pred, target)` tensors that are compared with MSE.
        """
        if (epoch % self.every_n_epochs) == 0:
            if self.train_overlapping_dataloader is not None:
                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
                                    get_model_prediction_and_target_callable)
            if self.val_dataloader is not None:
                self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)

    def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
            [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
        """
        Run one no-grad pass over `dataloader` and log the mean per-batch MSE loss as `loss/<tag>`.

        The caller's RNG state is isolated. Torch is reseeded at each step from `self.seed`, so any
        torch randomness consumed by the callable is repeatable across validation passes. If the
        dataloader yields no batches, a warning is logged and no scalar is written.
        """
        with torch.no_grad(), isolate_rng():
            loss_validation_epoch = []
            # context manager guarantees the bar is closed even if the callable raises
            with tqdm(range(len(dataloader)), position=1) as steps_pbar:
                steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
                for step, batch in enumerate(dataloader):
                    # ok to override seed here because we are in a `with isolate_rng():` block
                    torch.manual_seed(self.seed + step)
                    model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                    # drop references to the prediction/target tensors before the next batch
                    del target, model_pred
                    loss_validation_epoch.append(loss.detach().item())
                    steps_pbar.update(1)

            if not loss_validation_epoch:
                # an empty dataloader would otherwise raise ZeroDivisionError when averaging
                logging.warning(f" * Validation ({tag}) produced no batches; skipping loss logging")
                return

            loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
            self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)

    def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
            -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
        """
        Build the validation dataloader according to config['val_split_mode'].

          * 'none' or None (or config['validate_training'] is false): no dataloader; training items
            are returned unchanged.
          * 'automatic': a random `val_split_proportion` of the training items is removed from the
            training set and used for validation, with flip disabled and multiplier reset to 1.
          * 'manual': validation items are resolved from config['val_data_root']; training items
            are returned unchanged.

        :raises ValueError: for any other mode.
        :return: `(dataloader or None, remaining training items)`.
        """
        val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None
        val_split_proportion = self.config['val_split_proportion']
        remaining_train_items = image_train_items
        if val_split_mode is None or val_split_mode == 'none':
            return None, image_train_items
        elif val_split_mode == 'automatic':
            val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
            val_items = list(disable_multiplier_and_flip(val_items))
            logging.info(f" * Removed {len(val_items)} images from the training set to use for validation")
        elif val_split_mode == 'manual':
            args = Namespace(
                aspects=aspects.get_aspect_buckets(self.resolution),
                flip_p=0.0,
                seed=self.seed,
            )
            val_data_root = self.config['val_data_root']
            val_items = resolver.resolve_root(val_data_root, args)
            logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}")
        else:
            raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
        val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
        val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
        return val_dataloader, remaining_train_items

    def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
            -> Optional[torch.utils.data.DataLoader]:
        """
        If config['stabilize_training_loss'] is true, build a dataloader over a random
        `stabilize_split_proportion` subset of `image_train_items`, with flip disabled and
        multiplier reset to 1. The chosen items are NOT removed from the training set.

        :return: the dataloader, or None when stabilization is disabled.
        """
        stabilize_training_loss = self.config['stabilize_training_loss']
        if not stabilize_training_loss:
            return None

        stabilize_split_proportion = self.config['stabilize_split_proportion']
        # the second element (remaining items) is discarded: these items stay in training
        stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
        stabilize_items = list(disable_multiplier_and_flip(stabilize_items))
        stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
                                                  name='stabilize-train')
        stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
        return stabilize_dataloader

    def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
        """
        Wrap `items` in a `DataLoaderMultiAspect` plus `EveryDreamBatch`, both seeded with `self.seed`.

        :param items: the items to load.
        :param batch_size: batch size for the aspect-bucketing loader. This is honored as passed;
            every caller in this class passes `self.batch_size`.
        :param tokenizer: forwarded to `EveryDreamBatch`.
        :param name: forwarded to `EveryDreamBatch`.
        """
        seed = self.seed
        data_loader = DataLoaderMultiAspect(
            items,
            batch_size=batch_size,
            seed=seed,
        )
        ed_batch = EveryDreamBatch(
            data_loader=data_loader,
            debug_level=1,
            # NOTE(review): presumably disables conditional (caption) dropout for validation - confirm
            conditional_dropout=0,
            tokenizer=tokenizer,
            seed=seed,
            name=name,
        )
        return ed_batch
|