GH-36: Add support for validation split (WIP)
Co-authored-by: Damian Stewart <office@damianstewart.com>
This commit is contained in:
parent
85f19b9a2f
commit
41c9f36ed7
|
@ -14,11 +14,12 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
import bisect
|
||||
from functools import reduce
|
||||
import math
|
||||
import copy
|
||||
|
||||
import random
|
||||
from data.image_train_item import ImageTrainItem
|
||||
from data.image_train_item import ImageTrainItem, ImageCaption
|
||||
import PIL
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
||||
|
@ -27,23 +28,19 @@ class DataLoaderMultiAspect():
|
|||
"""
|
||||
Data loader for multi-aspect-ratio training and bucketing
|
||||
|
||||
image_train_items: list of `lImageTrainItem` objects
|
||||
image_train_items: list of `ImageTrainItem` objects
|
||||
seed: random seed
|
||||
batch_size: number of images per batch
|
||||
"""
|
||||
def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
|
||||
self.seed = seed
|
||||
self.batch_size = batch_size
|
||||
# Prepare data
|
||||
self.prepared_train_data = image_train_items
|
||||
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
||||
# Initialize ratings
|
||||
self.rating_overall_sum: float = 0.0
|
||||
self.ratings_summed: list[float] = []
|
||||
for image in self.prepared_train_data:
|
||||
self.rating_overall_sum += image.caption.rating()
|
||||
self.ratings_summed.append(self.rating_overall_sum)
|
||||
self.__update_rating_sums()
|
||||
|
||||
def __pick_multiplied_set(self, randomizer):
|
||||
"""
|
||||
|
@ -80,14 +77,17 @@ class DataLoaderMultiAspect():
|
|||
del data_copy
|
||||
return picked_images
|
||||
|
||||
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
|
||||
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
|
||||
"""
|
||||
returns the current list of images including their captions in a randomized order,
|
||||
sorted into buckets with same sized images
|
||||
if dropout_fraction < 1.0, only a subset of the images will be returned
|
||||
if dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan
|
||||
Returns the current list of `ImageTrainItem` in randomized order,
|
||||
sorted into buckets with same sized images.
|
||||
|
||||
If dropout_fraction < 1.0, only a subset of the images will be returned.
|
||||
|
||||
If dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan.
|
||||
|
||||
:param dropout_fraction: must be between 0.0 and 1.0.
|
||||
:return: randomized list of (image, caption) pairs, sorted into same sized buckets
|
||||
:return: Randomized list of `ImageTrainItem` objects
|
||||
"""
|
||||
|
||||
self.seed += 1
|
||||
|
@ -126,11 +126,11 @@ class DataLoaderMultiAspect():
|
|||
buckets[bucket].extend(runt_bucket)
|
||||
|
||||
# flatten the buckets
|
||||
image_caption_pairs = []
|
||||
items: list[ImageTrainItem] = []
|
||||
for bucket in buckets:
|
||||
image_caption_pairs.extend(buckets[bucket])
|
||||
items.extend(buckets[bucket])
|
||||
|
||||
return image_caption_pairs
|
||||
return items
|
||||
|
||||
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
|
||||
"""
|
||||
|
@ -168,3 +168,28 @@ class DataLoaderMultiAspect():
|
|||
prepared_train_data.pop(pos)
|
||||
|
||||
return picked_images
|
||||
|
||||
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
|
||||
item_count = math.ceil(split_proportion * len(self.prepared_train_data) // self.batch_size) * self.batch_size
|
||||
# sort first, then shuffle, to ensure determinate outcome for the current random state
|
||||
items_copy = list(sorted(self.prepared_train_data, key=lambda i: i.pathname))
|
||||
random.shuffle(items_copy)
|
||||
split_items = items_copy[:item_count]
|
||||
if remove_from_dataset:
|
||||
self.delete_items(split_items)
|
||||
return split_items
|
||||
|
||||
def delete_items(self, items: list[ImageTrainItem]):
|
||||
for item in items:
|
||||
for i, other_item in enumerate(self.prepared_train_data):
|
||||
if other_item.pathname == item.pathname:
|
||||
self.prepared_train_data.pop(i)
|
||||
break
|
||||
self.__update_rating_sums()
|
||||
|
||||
def __update_rating_sums(self):
|
||||
self.rating_overall_sum: float = 0.0
|
||||
self.ratings_summed: list[float] = []
|
||||
for item in self.prepared_train_data:
|
||||
self.rating_overall_sum += item.caption.rating()
|
||||
self.ratings_summed.append(self.rating_overall_sum)
|
|
@ -1,2 +0,0 @@
|
|||
# stop lightning's repeated instantiation of batch train/val/test classes causing multiple sweeps of the same data off disk
|
||||
shared_dataloader = None
|
|
@ -1,70 +0,0 @@
|
|||
"""
|
||||
Copyright [2022] Victor C Hall
|
||||
|
||||
Licensed under the GNU Affero General Public License;
|
||||
You may not use this code except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https://www.gnu.org/licenses/agpl-3.0.en.html
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from ldm.data.data_loader import DataLoaderMultiAspect as dlma
|
||||
import math
|
||||
import ldm.data.dl_singleton as dls
|
||||
from ldm.data.image_train_item import ImageTrainItem
|
||||
|
||||
class EDValidateBatch(Dataset):
|
||||
def __init__(self,
|
||||
data_root,
|
||||
flip_p=0.0,
|
||||
repeats=1,
|
||||
debug_level=0,
|
||||
batch_size=1,
|
||||
set='val',
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.batch_size = batch_size
|
||||
|
||||
if not dls.shared_dataloader:
|
||||
print("Creating new dataloader singleton")
|
||||
dls.shared_dataloader = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size, flip_p=flip_p)
|
||||
|
||||
self.image_train_items = dls.shared_dataloader.get_all_images()
|
||||
|
||||
self.num_images = len(self.image_train_items)
|
||||
|
||||
self._length = max(math.trunc(self.num_images * repeats), batch_size) - self.num_images % self.batch_size
|
||||
|
||||
print()
|
||||
print(f" ** Validation Set: {set}, steps: {self._length / batch_size:.0f}, repeats: {repeats} ")
|
||||
print()
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
idx = i % self.num_images
|
||||
image_train_item = self.image_train_items[idx]
|
||||
|
||||
example = self.__get_image_for_trainer(image_train_item)
|
||||
return example
|
||||
|
||||
@staticmethod
|
||||
def __get_image_for_trainer(image_train_item: ImageTrainItem):
|
||||
example = {}
|
||||
|
||||
image_train_tmp = image_train_item.hydrate()
|
||||
|
||||
example["image"] = image_train_tmp.image
|
||||
example["caption"] = image_train_tmp.caption
|
||||
|
||||
return example
|
||||
|
|
@ -41,7 +41,8 @@ class EveryDreamBatch(Dataset):
|
|||
retain_contrast=False,
|
||||
shuffle_tags=False,
|
||||
rated_dataset=False,
|
||||
rated_dataset_dropout_target=0.5
|
||||
rated_dataset_dropout_target=0.5,
|
||||
name='train'
|
||||
):
|
||||
self.data_loader = data_loader
|
||||
self.batch_size = data_loader.batch_size
|
||||
|
@ -57,11 +58,19 @@ class EveryDreamBatch(Dataset):
|
|||
self.rated_dataset = rated_dataset
|
||||
self.rated_dataset_dropout_target = rated_dataset_dropout_target
|
||||
# First epoch always trains on all images
|
||||
self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0)
|
||||
self.image_train_items = []
|
||||
self.__update_image_train_items(1.0)
|
||||
self.name = name
|
||||
|
||||
num_images = len(self.image_train_items)
|
||||
logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
|
||||
|
||||
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
|
||||
items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
|
||||
self.__update_image_train_items(1.0)
|
||||
return items
|
||||
|
||||
|
||||
def shuffle(self, epoch_n: int, max_epochs: int):
|
||||
self.seed += 1
|
||||
|
||||
|
@ -69,8 +78,8 @@ class EveryDreamBatch(Dataset):
|
|||
dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
|
||||
else:
|
||||
dropout_fraction = 1.0
|
||||
|
||||
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
|
||||
|
||||
self.__update_image_train_items(dropout_fraction)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_train_items)
|
||||
|
@ -130,3 +139,38 @@ class EveryDreamBatch(Dataset):
|
|||
example["caption"] = image_train_tmp.caption
|
||||
example["runt_size"] = image_train_tmp.runt_size
|
||||
return example
|
||||
|
||||
def __update_image_train_items(self, dropout_fraction: float):
|
||||
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
|
||||
|
||||
def build_torch_dataloader(items, batch_size) -> torch.utils.data.DataLoader:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
items,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=0,
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
return dataloader
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
"""
|
||||
Collates batches
|
||||
"""
|
||||
images = [example["image"] for example in batch]
|
||||
captions = [example["caption"] for example in batch]
|
||||
tokens = [example["tokens"] for example in batch]
|
||||
runt_size = batch[0]["runt_size"]
|
||||
|
||||
images = torch.stack(images)
|
||||
images = images.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
ret = {
|
||||
"tokens": torch.stack(tuple(tokens)),
|
||||
"image": images,
|
||||
"captions": captions,
|
||||
"runt_size": runt_size,
|
||||
}
|
||||
del batch
|
||||
return ret
|
|
@ -0,0 +1,170 @@
|
|||
import json
|
||||
import random
|
||||
from typing import Callable, Any, Optional
|
||||
from argparse import Namespace
|
||||
|
||||
import torch
|
||||
from colorama import Fore, Style
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from data.every_dream import build_torch_dataloader, EveryDreamBatch
|
||||
from data.data_loader import DataLoaderMultiAspect
|
||||
from data import resolver
|
||||
from data import aspects
|
||||
from utils.isolate_rng import isolate_rng
|
||||
|
||||
|
||||
class EveryDreamValidator:
|
||||
def __init__(self,
|
||||
val_config_path: Optional[str],
|
||||
train_batch: EveryDreamBatch,
|
||||
log_writer: SummaryWriter):
|
||||
self.log_writer = log_writer
|
||||
|
||||
val_config = {}
|
||||
if val_config_path is not None:
|
||||
with open(val_config_path, 'rt') as f:
|
||||
val_config = json.load(f)
|
||||
|
||||
do_validation = val_config.get('validate_training', False)
|
||||
val_split_mode = val_config.get('val_split_mode', 'automatic') if do_validation else 'none'
|
||||
self.val_data_root = val_config.get('val_data_root', None)
|
||||
val_split_proportion = val_config.get('val_split_proportion', 0.15)
|
||||
|
||||
stabilize_training_loss = val_config.get('stabilize_training_loss', False)
|
||||
stabilize_split_proportion = val_config.get('stabilize_split_proportion', 0.15)
|
||||
|
||||
self.every_n_epochs = val_config.get('every_n_epochs', 1)
|
||||
self.seed = val_config.get('seed', 555)
|
||||
|
||||
with isolate_rng():
|
||||
self.val_dataloader = self._build_validation_dataloader(val_split_mode,
|
||||
split_proportion=val_split_proportion,
|
||||
val_data_root=self.val_data_root,
|
||||
train_batch=train_batch)
|
||||
# order is important - if we're removing images from train, this needs to happen before making
|
||||
# the overlapping dataloader
|
||||
self.train_overlapping_dataloader = self._build_dataloader_from_automatic_split(train_batch,
|
||||
split_proportion=stabilize_split_proportion,
|
||||
name='train-stabilizer',
|
||||
enforce_split=False) if stabilize_training_loss else None
|
||||
|
||||
|
||||
def do_validation_if_appropriate(self, epoch: int, global_step: int,
|
||||
get_model_prediction_and_target_callable: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||
if (epoch % self.every_n_epochs) == 0:
|
||||
if self.train_overlapping_dataloader is not None:
|
||||
self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader, get_model_prediction_and_target_callable)
|
||||
if self.val_dataloader is not None:
|
||||
self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)
|
||||
|
||||
|
||||
def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]):
|
||||
with torch.no_grad(), isolate_rng():
|
||||
loss_validation_epoch = []
|
||||
steps_pbar = tqdm(range(len(dataloader)), position=1)
|
||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")
|
||||
|
||||
for step, batch in enumerate(dataloader):
|
||||
# ok to override seed here because we are in a `with isolate_rng():` block
|
||||
torch.manual_seed(self.seed + step)
|
||||
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
|
||||
|
||||
# del timesteps, encoder_hidden_states, noisy_latents
|
||||
# with autocast(enabled=args.amp):
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
del target, model_pred
|
||||
|
||||
loss_step = loss.detach().item()
|
||||
loss_validation_epoch.append(loss_step)
|
||||
|
||||
steps_pbar.update(1)
|
||||
|
||||
steps_pbar.close()
|
||||
|
||||
loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
|
||||
self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)
|
||||
|
||||
|
||||
def _build_validation_dataloader(self,
|
||||
validation_split_mode: str,
|
||||
split_proportion: float,
|
||||
val_data_root: Optional[str],
|
||||
train_batch: EveryDreamBatch) -> Optional[DataLoader]:
|
||||
if validation_split_mode == 'none':
|
||||
return None
|
||||
elif validation_split_mode == 'automatic':
|
||||
return self._build_dataloader_from_automatic_split(train_batch, split_proportion, name='val', enforce_split=True)
|
||||
elif validation_split_mode == 'manual':
|
||||
if val_data_root is None:
|
||||
raise ValueError("val_data_root is required for 'manual' validation split mode")
|
||||
return self._build_dataloader_from_custom_split(self.val_data_root, reference_train_batch=train_batch)
|
||||
else:
|
||||
raise ValueError(f"unhandled validation split mode '{validation_split_mode}'")
|
||||
|
||||
|
||||
def _build_dataloader_from_automatic_split(self,
|
||||
train_batch: EveryDreamBatch,
|
||||
split_proportion: float,
|
||||
name: str,
|
||||
enforce_split: bool=False) -> DataLoader:
|
||||
"""
|
||||
Build a validation dataloader by copying data from the given `train_batch`. If `enforce_split` is `True`, remove
|
||||
the copied items from train_batch so that there is no overlap between `train_batch` and the new dataloader.
|
||||
"""
|
||||
with isolate_rng():
|
||||
random.seed(self.seed)
|
||||
val_items = train_batch.get_random_split(split_proportion, remove_from_dataset=enforce_split)
|
||||
if enforce_split:
|
||||
print(
|
||||
f" * Removed {len(val_items)} items for validation split from '{train_batch.name}' - {round(len(train_batch)/train_batch.batch_size)} batches are left")
|
||||
if len(train_batch) == 0:
|
||||
raise ValueError(f"Validation split used up all of the training data. Try a lower split proportion than {split_proportion}")
|
||||
val_batch = self._make_val_batch_with_train_batch_settings(
|
||||
val_items,
|
||||
train_batch,
|
||||
name=name
|
||||
)
|
||||
return build_torch_dataloader(
|
||||
items=val_batch,
|
||||
batch_size=train_batch.batch_size,
|
||||
)
|
||||
|
||||
|
||||
def _build_dataloader_from_custom_split(self, data_root: str, reference_train_batch: EveryDreamBatch) -> DataLoader:
|
||||
val_batch = self._make_val_batch_with_train_batch_settings(data_root, reference_train_batch)
|
||||
return build_torch_dataloader(
|
||||
items=val_batch,
|
||||
batch_size=reference_train_batch.batch_size
|
||||
)
|
||||
|
||||
def _make_val_batch_with_train_batch_settings(self, data_root, reference_train_batch, name='val'):
|
||||
batch_size = reference_train_batch.batch_size
|
||||
seed = reference_train_batch.seed
|
||||
args = Namespace(
|
||||
aspects=aspects.get_aspect_buckets(512),
|
||||
flip_p=0.0,
|
||||
seed=seed,
|
||||
)
|
||||
image_train_items = resolver.resolve(data_root, args)
|
||||
data_loader = DataLoaderMultiAspect(
|
||||
image_train_items,
|
||||
batch_size=batch_size,
|
||||
seed=seed,
|
||||
)
|
||||
val_batch = EveryDreamBatch(
|
||||
data_loader=data_loader,
|
||||
debug_level=1,
|
||||
batch_size=batch_size,
|
||||
conditional_dropout=0,
|
||||
tokenizer=reference_train_batch.tokenizer,
|
||||
seed=seed,
|
||||
name=name,
|
||||
)
|
||||
return val_batch
|
|
@ -0,0 +1,73 @@
|
|||
# copy/pasted from pytorch lightning
|
||||
# https://github.com/Lightning-AI/lightning/blob/0d52f4577310b5a1624bed4d23d49e37fb05af9e/src/lightning_fabric/utilities/seed.py
|
||||
# and
|
||||
# https://github.com/Lightning-AI/lightning/blob/98f7696d1681974d34fad59c03b4b58d9524ed13/src/pytorch_lightning/utilities/seed.py
|
||||
|
||||
# Copyright The Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator, Dict, Any
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from random import getstate as python_get_rng_state
|
||||
from random import setstate as python_set_rng_state
|
||||
|
||||
|
||||
def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
|
||||
"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
|
||||
states = {
|
||||
"torch": torch.get_rng_state(),
|
||||
"numpy": np.random.get_state(),
|
||||
"python": python_get_rng_state(),
|
||||
}
|
||||
if include_cuda:
|
||||
states["torch.cuda"] = torch.cuda.get_rng_state_all()
|
||||
return states
|
||||
|
||||
|
||||
def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
|
||||
"""Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current
|
||||
process."""
|
||||
torch.set_rng_state(rng_state_dict["torch"])
|
||||
# torch.cuda rng_state is only included since v1.8.
|
||||
if "torch.cuda" in rng_state_dict:
|
||||
torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
|
||||
np.random.set_state(rng_state_dict["numpy"])
|
||||
version, state, gauss = rng_state_dict["python"]
|
||||
python_set_rng_state((version, tuple(state), gauss))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
|
||||
"""A context manager that resets the global random state on exit to what it was before entering.
|
||||
It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.
|
||||
Args:
|
||||
include_cuda: Whether to allow this function to also control the `torch.cuda` random number generator.
|
||||
Set this to ``False`` when using the function in a forked process where CUDA re-initialization is
|
||||
prohibited.
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> torch.manual_seed(1) # doctest: +ELLIPSIS
|
||||
<torch._C.Generator object at ...>
|
||||
>>> with isolate_rng():
|
||||
... [torch.rand(1) for _ in range(3)]
|
||||
[tensor([0.7576]), tensor([0.2793]), tensor([0.4031])]
|
||||
>>> torch.rand(1)
|
||||
tensor([0.7576])
|
||||
"""
|
||||
states = _collect_rng_states(include_cuda)
|
||||
yield
|
||||
_set_rng_states(states)
|
|
@ -0,0 +1,65 @@
|
|||
import argparse
|
||||
import math
|
||||
import os.path
|
||||
import random
|
||||
import shutil
|
||||
from typing import Optional
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']
|
||||
|
||||
|
||||
def gather_captioned_images(root_dir: str) -> list[tuple[str,Optional[str]]]:
|
||||
for directory, _, filenames in os.walk(root_dir):
|
||||
image_filenames = [f for f in filenames if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS]
|
||||
for image_filename in image_filenames:
|
||||
caption_filename = os.path.splitext(image_filename)[0] + '.txt'
|
||||
image_path = os.path.join(directory+image_filename)
|
||||
caption_path = os.path.join(directory+caption_filename)
|
||||
yield (image_path, caption_path if os.path.exists(caption_path) else None)
|
||||
|
||||
|
||||
def copy_captioned_image(image_caption_pair: tuple[str, Optional[str]], source_root: str, target_root: str):
|
||||
image_path = image_caption_pair[0]
|
||||
caption_path = image_caption_pair[1]
|
||||
|
||||
# make target folder if necessary
|
||||
relative_folder = os.path.dirname(os.path.relpath(image_path, source_root))
|
||||
target_folder = os.path.join(target_root, relative_folder)
|
||||
os.makedirs(target_folder, exist_ok=True)
|
||||
|
||||
# copy files
|
||||
shutil.copy2(image_path, os.path.join(target_folder, os.path.basename(image_path)))
|
||||
if caption_path is not None:
|
||||
shutil.copy2(caption_path, os.path.join(target_folder, os.path.basename(caption_path)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('source_root', type=str, help='Source root folder containing images')
|
||||
parser.add_argument('--train_output_folder', type=str, required=True, help="Output folder for the 'train' dataset")
|
||||
parser.add_argument('--val_output_folder', type=str, required=True, help="Output folder for the 'val' dataset")
|
||||
parser.add_argument('--split_proportion', type=float, required=True, help="Proportion of images to use for 'val' (a number between 0 and 1)")
|
||||
parser.add_argument('--seed', type=int, required=False, default=555, help='Random seed for shuffling')
|
||||
args = parser.parse_args()
|
||||
|
||||
images = gather_captioned_images(args.source_root)
|
||||
print(f"Found {len(images)} captioned images in {args.source_root}")
|
||||
val_split_count = math.ceil(len(images) * args.split_proportion)
|
||||
if val_split_count == 0:
|
||||
raise ValueError(f"No images in validation split with source count {len(images)} and split proportion {args.split_proportion}")
|
||||
|
||||
random.seed(args.seed)
|
||||
random.shuffle(images)
|
||||
val_split = images[0:val_split_count]
|
||||
train_split = images[val_split_count:]
|
||||
print(f"Split to 'train' set with {len(train_split)} images and 'val' set with {len(val_split)}")
|
||||
print(f"Copying 'val' set to {args.val_output_folder}...")
|
||||
for v in tqdm(val_split):
|
||||
copy_captioned_image(v, args.source_root, args.val_output_folder)
|
||||
print(f"Copying 'train' set to {args.train_output_folder}...")
|
||||
for v in tqdm(train_split):
|
||||
copy_captioned_image(v, args.source_root, args.train_output_folder)
|
||||
print("Done.")
|
|
@ -0,0 +1,20 @@
|
|||
{
|
||||
"documentation": {
|
||||
"validate_training": "If true, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.",
|
||||
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
|
||||
"val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
||||
"val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
|
||||
"stabilize_training_loss": "If true, stabilise the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
|
||||
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
||||
"every_n_epochs": "How often to run validation (1=every epoch).",
|
||||
"seed": "The seed to use when running validation and stabilization passes."
|
||||
},
|
||||
"validate_training": true,
|
||||
"val_split_mode": "automatic",
|
||||
"val_data_root": null,
|
||||
"val_split_proportion": 0.15,
|
||||
"stabilize_training_loss": true,
|
||||
"stabilize_split_proportion": 0.15,
|
||||
"every_n_epochs": 1,
|
||||
"seed": 555
|
||||
}
|
Loading…
Reference in New Issue