Merge pull request #50 from noprompt/add-support-for-val-split

Add support for validation split
This commit is contained in:
Victor Hall 2023-02-07 18:16:34 -05:00 committed by GitHub
commit 165525a71c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 473 additions and 173 deletions

View File

@ -14,11 +14,12 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import bisect
from functools import reduce
import math
import copy
import random
from data.image_train_item import ImageTrainItem
from data.image_train_item import ImageTrainItem, ImageCaption
import PIL
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@ -27,23 +28,19 @@ class DataLoaderMultiAspect():
"""
Data loader for multi-aspect-ratio training and bucketing
image_train_items: list of `lImageTrainItem` objects
image_train_items: list of `ImageTrainItem` objects
seed: random seed
batch_size: number of images per batch
"""
def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
self.seed = seed
self.batch_size = batch_size
# Prepare data
self.prepared_train_data = image_train_items
random.Random(self.seed).shuffle(self.prepared_train_data)
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
# Initialize ratings
self.rating_overall_sum: float = 0.0
self.ratings_summed: list[float] = []
for image in self.prepared_train_data:
self.rating_overall_sum += image.caption.rating()
self.ratings_summed.append(self.rating_overall_sum)
self.__update_rating_sums()
def __pick_multiplied_set(self, randomizer):
"""
@ -80,14 +77,17 @@ class DataLoaderMultiAspect():
del data_copy
return picked_images
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
"""
returns the current list of images including their captions in a randomized order,
sorted into buckets with same sized images
if dropout_fraction < 1.0, only a subset of the images will be returned
if dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan
Returns the current list of `ImageTrainItem` in randomized order,
sorted into buckets with same sized images.
If dropout_fraction < 1.0, only a subset of the images will be returned.
If dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan.
:param dropout_fraction: must be between 0.0 and 1.0.
:return: randomized list of (image, caption) pairs, sorted into same sized buckets
:return: Randomized list of `ImageTrainItem` objects
"""
self.seed += 1
@ -126,11 +126,11 @@ class DataLoaderMultiAspect():
buckets[bucket].extend(runt_bucket)
# flatten the buckets
image_caption_pairs = []
items: list[ImageTrainItem] = []
for bucket in buckets:
image_caption_pairs.extend(buckets[bucket])
items.extend(buckets[bucket])
return image_caption_pairs
return items
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
"""
@ -168,3 +168,10 @@ class DataLoaderMultiAspect():
prepared_train_data.pop(pos)
return picked_images
    def __update_rating_sums(self):
        """
        Recompute the cumulative caption-rating sums over `self.prepared_train_data`.

        After this runs, `self.ratings_summed[i]` is the sum of
        `caption.rating()` for items 0..i (a prefix-sum list) and
        `self.rating_overall_sum` is the grand total.
        # NOTE(review): the prefix sums presumably support weighted random
        # picking via bisection (the file imports `bisect`) — confirm against
        # __pick_random_subset.
        """
        self.rating_overall_sum: float = 0.0
        self.ratings_summed: list[float] = []
        for item in self.prepared_train_data:
            # append after adding so each entry is a running prefix sum
            self.rating_overall_sum += item.caption.rating()
            self.ratings_summed.append(self.rating_overall_sum)

View File

@ -1,2 +0,0 @@
# stop lightning's repeated instantiation of batch train/val/test classes causing multiple sweeps of the same data off disk
shared_dataloader = None

View File

@ -1,70 +0,0 @@
"""
Copyright [2022] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from torch.utils.data import Dataset
from ldm.data.data_loader import DataLoaderMultiAspect as dlma
import math
import ldm.data.dl_singleton as dls
from ldm.data.image_train_item import ImageTrainItem
class EDValidateBatch(Dataset):
    """
    torch `Dataset` over validation images served by the shared
    `DataLoaderMultiAspect` singleton (`dls.shared_dataloader`).

    The singleton exists so that repeated dataset instantiations (e.g. by
    lightning) do not re-scan the same data from disk.

    NOTE(review): the `set` parameter shadows the builtin and is only used
    in the log message below.
    """
    def __init__(self,
                 data_root,
                 flip_p=0.0,
                 repeats=1,
                 debug_level=0,
                 batch_size=1,
                 set='val',
                 ):
        self.data_root = data_root
        self.batch_size = batch_size

        # create the singleton on first use; flip_p/debug_level are only
        # honored on this first creation
        if not dls.shared_dataloader:
            print("Creating new dataloader singleton")
            dls.shared_dataloader = dlma(data_root=data_root, debug_level=debug_level, batch_size=self.batch_size, flip_p=flip_p)

        self.image_train_items = dls.shared_dataloader.get_all_images()

        self.num_images = len(self.image_train_items)

        # NOTE(review): this subtracts num_images % batch_size (not
        # _length % batch_size) — looks like an attempt to trim the repeated
        # length to whole batches; confirm the intent.
        self._length = max(math.trunc(self.num_images * repeats), batch_size) - self.num_images % self.batch_size
        print()
        print(f" ** Validation Set: {set}, steps: {self._length / batch_size:.0f}, repeats: {repeats} ")
        print()

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        # wrap around so indexes beyond num_images repeat the dataset
        idx = i % self.num_images
        image_train_item = self.image_train_items[idx]

        example = self.__get_image_for_trainer(image_train_item)
        return example

    @staticmethod
    def __get_image_for_trainer(image_train_item: ImageTrainItem):
        """Return a dict with the item's image and caption.

        `hydrate()` presumably loads the actual image data from disk —
        confirm in ImageTrainItem.
        """
        example = {}

        image_train_tmp = image_train_item.hydrate()

        example["image"] = image_train_tmp.image
        example["caption"] = image_train_tmp.caption

        return example

View File

@ -41,7 +41,8 @@ class EveryDreamBatch(Dataset):
retain_contrast=False,
shuffle_tags=False,
rated_dataset=False,
rated_dataset_dropout_target=0.5
rated_dataset_dropout_target=0.5,
name='train'
):
self.data_loader = data_loader
self.batch_size = data_loader.batch_size
@ -57,10 +58,18 @@ class EveryDreamBatch(Dataset):
self.rated_dataset = rated_dataset
self.rated_dataset_dropout_target = rated_dataset_dropout_target
# First epoch always trains on all images
self.image_train_items = self.data_loader.get_shuffled_image_buckets(1.0)
self.image_train_items = []
self.__update_image_train_items(1.0)
self.name = name
num_images = len(self.image_train_items)
logging.info(f" ** Trainer Set: {num_images / self.batch_size:.0f}, num_images: {num_images}, batch_size: {self.batch_size}")
logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}")
def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]:
items = self.data_loader.get_random_split(split_proportion, remove_from_dataset)
self.__update_image_train_items(1.0)
return items
def shuffle(self, epoch_n: int, max_epochs: int):
self.seed += 1
@ -69,8 +78,8 @@ class EveryDreamBatch(Dataset):
dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
else:
dropout_fraction = 1.0
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
self.__update_image_train_items(dropout_fraction)
def __len__(self):
return len(self.image_train_items)
@ -130,3 +139,38 @@ class EveryDreamBatch(Dataset):
example["caption"] = image_train_tmp.caption
example["runt_size"] = image_train_tmp.runt_size
return example
    def __update_image_train_items(self, dropout_fraction: float):
        # Re-pull the (possibly dropout-reduced) bucketed item list from the
        # shared data loader; invoked on construction and on every shuffle().
        self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
    """Wrap `dataset` in a torch DataLoader using this module's `collate_fn`.

    Shuffling is disabled here; item order is produced upstream by the
    bucketing data loader that built the dataset.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn,
    )
def collate_fn(batch):
    """
    Combine a list of per-example dicts into one batched dict.

    Image tensors are stacked into a single contiguous float tensor and
    token tensors into a single tensor; captions pass through as a list.
    `runt_size` is taken from the first example of the batch.
    """
    stacked_images = torch.stack([ex["image"] for ex in batch])
    stacked_images = stacked_images.to(memory_format=torch.contiguous_format).float()
    collated = {
        "tokens": torch.stack(tuple(ex["tokens"] for ex in batch)),
        "image": stacked_images,
        "captions": [ex["caption"] for ex in batch],
        "runt_size": batch[0]["runt_size"],
    }
    del batch
    return collated

View File

@ -0,0 +1,173 @@
import json
import math
import random
from typing import Callable, Any, Optional
from argparse import Namespace
import torch
from colorama import Fore, Style
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
from data.every_dream import build_torch_dataloader, EveryDreamBatch
from data.data_loader import DataLoaderMultiAspect
from data import resolver
from data import aspects
from data.image_train_item import ImageTrainItem
from utils.isolate_rng import isolate_rng
def get_random_split(items: 'list[ImageTrainItem]', split_proportion: float, batch_size: int) \
        -> 'tuple[list[ImageTrainItem], list[ImageTrainItem]]':
    """
    Randomly split `items` into (split_items, remaining_items).

    The split size is `split_proportion` of `len(items)`, rounded UP to a
    whole number of batches of `batch_size` so the split fills its final
    batch.

    Fix: the original computed `math.ceil(x // batch_size)` — the floor
    division already truncates, making the `ceil` a no-op and silently
    rounding the split size DOWN; true division lets `ceil` take effect.

    :param items: pool to split (not mutated)
    :param split_proportion: fraction in [0, 1] to place in the first list
    :param batch_size: split size is rounded up to a multiple of this
    :return: (split_items, remaining_items)
    """
    split_item_count = math.ceil(split_proportion * len(items) / batch_size) * batch_size
    # sort first, then shuffle, to ensure determinate outcome for the current random state
    items_copy = sorted(items, key=lambda i: i.pathname)
    random.shuffle(items_copy)
    split_items = items_copy[:split_item_count]
    remaining_items = items_copy[split_item_count:]
    return split_items, remaining_items
class EveryDreamValidator:
    """
    Drives validation during training: builds an optional validation
    dataloader (and an optional "stabilized" train-loss dataloader) and
    periodically computes/logs mean MSE losses to tensorboard.
    """
    def __init__(self,
                 val_config_path: Optional[str],
                 default_batch_size: int,
                 log_writer: SummaryWriter):
        # dataloaders are created later by `prepare_validation_splits`
        self.val_dataloader = None
        self.train_overlapping_dataloader = None

        self.log_writer = log_writer

        # defaults; any keys present in the JSON file at `val_config_path`
        # override these
        self.config = {
            'batch_size': default_batch_size,
            'every_n_epochs': 1,
            'seed': 555,

            'val_split_mode': 'automatic',
            'val_split_proportion': 0.15,

            'stabilize_training_loss': False,
            'stabilize_split_proportion': 0.15
        }
        if val_config_path is not None:
            with open(val_config_path, 'rt') as f:
                self.config.update(json.load(f))

    @property
    def batch_size(self):
        # batch size used for all validation dataloaders
        return self.config['batch_size']

    @property
    def every_n_epochs(self):
        # validation runs whenever epoch % every_n_epochs == 0
        return self.config['every_n_epochs']

    @property
    def seed(self):
        # seed for the (isolated) RNG used during validation passes
        return self.config['seed']

    def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
        """
        Build the validation splits as requested by the config passed at init.
        This may steal some items from `train_items`.
        If this happens, the returned `list` contains the remaining items after the required items have been stolen.
        Otherwise, the returned `list` is identical to the passed-in `train_items`.
        """
        # isolate_rng so split selection doesn't perturb the training RNG stream
        with isolate_rng():
            self.val_dataloader, remaining_train_items = self._build_val_dataloader_if_required(train_items, tokenizer)
            # order is important - if we're removing images from train, this needs to happen before making
            # the overlapping dataloader
            self.train_overlapping_dataloader = self._build_train_stabilizer_dataloader_if_required(
                remaining_train_items, tokenizer)
            return remaining_train_items

    def do_validation_if_appropriate(self, epoch: int, global_step: int,
                                     get_model_prediction_and_target_callable: Callable[
                                         [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
        """
        Run the configured validation pass(es) if `epoch` is a multiple of
        `every_n_epochs`; otherwise do nothing.
        """
        if (epoch % self.every_n_epochs) == 0:
            if self.train_overlapping_dataloader is not None:
                self._do_validation('stabilize-train', global_step, self.train_overlapping_dataloader,
                                    get_model_prediction_and_target_callable)
            if self.val_dataloader is not None:
                self._do_validation('val', global_step, self.val_dataloader, get_model_prediction_and_target_callable)

    def _do_validation(self, tag, global_step, dataloader, get_model_prediction_and_target: Callable[
            [Any, Any], tuple[torch.Tensor, torch.Tensor]]):
        """
        Compute the mean MSE loss over `dataloader` (grad-free, RNG state
        isolated and re-seeded per step for repeatability) and log it to
        tensorboard as `loss/{tag}`.

        NOTE(review): assumes `dataloader` yields at least one batch — an
        empty dataloader would divide by zero below.
        """
        with torch.no_grad(), isolate_rng():
            loss_validation_epoch = []
            steps_pbar = tqdm(range(len(dataloader)), position=1)
            steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Validate ({tag}){Style.RESET_ALL}")

            for step, batch in enumerate(dataloader):
                # ok to override seed here because we are in a `with isolate_rng():` block
                torch.manual_seed(self.seed + step)
                model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                del target, model_pred

                loss_step = loss.detach().item()
                loss_validation_epoch.append(loss_step)

                steps_pbar.update(1)

            steps_pbar.close()

        loss_validation_local = sum(loss_validation_epoch) / len(loss_validation_epoch)
        self.log_writer.add_scalar(tag=f"loss/{tag}", scalar_value=loss_validation_local, global_step=global_step)

    def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\
            -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]:
        """
        Construct the validation dataloader per `val_split_mode`:
          - 'none': no validation; train items untouched.
          - 'automatic': steal `val_split_proportion` of the train items.
          - 'manual': load a separate dataset from `val_data_root`
            (train items untouched).

        :return: (val_dataloader or None, remaining train items)
        """
        val_split_mode = self.config['val_split_mode']
        val_split_proportion = self.config['val_split_proportion']
        remaining_train_items = image_train_items
        if val_split_mode == 'none':
            return None, image_train_items
        elif val_split_mode == 'automatic':
            # steal a random subset of the training items for validation
            val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size)
        elif val_split_mode == 'manual':
            # NOTE(review): resolution is hard-coded to 512 here — confirm
            # this matches the training resolution
            args = Namespace(
                aspects=aspects.get_aspect_buckets(512),
                flip_p=0.0,
                seed=self.seed,
            )
            val_data_root = self.config['val_data_root']
            val_items = resolver.resolve_root(val_data_root, args)
        else:
            raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'")
        val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val')
        val_dataloader = build_torch_dataloader(val_ed_batch, batch_size=self.batch_size)
        return val_dataloader, remaining_train_items

    def _build_train_stabilizer_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer) \
            -> Optional[torch.utils.data.DataLoader]:
        """
        When `stabilize_training_loss` is enabled, build a dataloader over a
        fixed random subset of the train items (NOT removed from training)
        used to compute a repeatable train-loss curve.
        """
        stabilize_training_loss = self.config['stabilize_training_loss']
        if not stabilize_training_loss:
            return None

        stabilize_split_proportion = self.config['stabilize_split_proportion']

        stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size)
        stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer,
                                                  name='stabilize-train')
        stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size)
        return stabilize_dataloader

    def _build_ed_batch(self, items: list[ImageTrainItem], batch_size: int, tokenizer, name='val'):
        """
        Wrap `items` in an `EveryDreamBatch` (via a fresh
        `DataLoaderMultiAspect`) configured for evaluation
        (conditional_dropout=0).

        NOTE(review): the `batch_size` parameter is immediately overwritten
        by `self.batch_size` below, so the caller-supplied value is ignored.
        """
        batch_size = self.batch_size
        seed = self.seed
        data_loader = DataLoaderMultiAspect(
            items,
            batch_size=batch_size,
            seed=seed,
        )
        ed_batch = EveryDreamBatch(
            data_loader=data_loader,
            debug_level=1,
            conditional_dropout=0,
            tokenizer=tokenizer,
            seed=seed,
            name=name,
        )
        return ed_batch

150
train.py
View File

@ -52,7 +52,8 @@ import wandb
from torch.utils.tensorboard import SummaryWriter
from data.data_loader import DataLoaderMultiAspect
from data.every_dream import EveryDreamBatch
from data.every_dream import EveryDreamBatch, build_torch_dataloader
from data.every_dream_validation import EveryDreamValidator
from data.image_train_item import ImageTrainItem
from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter
@ -349,29 +350,6 @@ def read_sample_prompts(sample_prompts_file_path: str):
return sample_prompts
def collate_fn(batch):
"""
Collates batches
"""
images = [example["image"] for example in batch]
captions = [example["caption"] for example in batch]
tokens = [example["tokens"] for example in batch]
runt_size = batch[0]["runt_size"]
images = torch.stack(images)
images = images.to(memory_format=torch.contiguous_format).float()
ret = {
"tokens": torch.stack(tuple(tokens)),
"image": images,
"captions": captions,
"runt_size": runt_size,
}
del batch
return ret
def main(args):
"""
Main entry point
@ -387,10 +365,13 @@ def main(args):
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
logging.info(f" Seed: {seed}")
set_seed(seed)
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
gpu = GPU()
device = torch.device(f"cuda:{args.gpuid}")
torch.backends.cudnn.benchmark = True
else:
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
device = 'cpu'
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
@ -606,6 +587,11 @@ def main(args):
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
log_writer = SummaryWriter(log_dir=log_folder,
flush_secs=5,
comment="EveryDream2FineTunes",
)
betas = (0.9, 0.999)
epsilon = 1e-8
if args.amp:
@ -630,9 +616,14 @@ def main(args):
)
log_optimizer(optimizer, betas, epsilon)
image_train_items = resolve_image_train_items(args, log_folder)
validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size)
# the validation dataset may need to steal some items from image_train_items
image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer)
data_loader = DataLoaderMultiAspect(
image_train_items=image_train_items,
seed=seed,
@ -668,12 +659,7 @@ def main(args):
if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name, sync_tensorboard=True, )
log_writer = SummaryWriter(
log_dir=log_folder,
flush_secs=5,
comment="EveryDream2FineTunes",
)
def log_args(log_writer, args):
arglog = "args:\n"
@ -729,15 +715,7 @@ def main(args):
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
train_dataloader = torch.utils.data.DataLoader(
train_batch,
batch_size=args.batch_size,
shuffle=False,
num_workers=4,
collate_fn=collate_fn,
pin_memory=True
)
train_dataloader = build_torch_dataloader(train_batch, batch_size=args.batch_size)
unet.train() if not args.disable_unet_training else unet.eval()
text_encoder.train() if not args.disable_textenc_training else text_encoder.eval()
@ -775,7 +753,49 @@ def main(args):
loss_log_step = []
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
# actual prediction function - shared between train and validate
    def get_model_prediction_and_target(image, tokens):
        """
        Run one noised forward pass for a batch and return (model_pred, target).

        Encodes `image` to VAE latents, samples noise and random timesteps,
        builds text-encoder hidden states from `tokens`, then predicts with
        the unet on the noised latents. The target is chosen per the
        scheduler's prediction_type (noise for "epsilon", velocity for
        v-prediction).

        Closes over `vae`, `text_encoder`, `unet`, `noise_scheduler`, and
        `args` from the enclosing training scope; shared between the train
        loop and the validator.
        """
        with torch.no_grad():
            with autocast(enabled=args.amp):
                pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
                latents = vae.encode(pixel_values, return_dict=False)
            del pixel_values
            # 0.18215 — the conventional Stable Diffusion VAE latent scale
            latents = latents[0].sample() * 0.18215

            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            cuda_caption = tokens.to(text_encoder.device)

        # with autocast(enabled=args.amp):
        encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)

        if args.clip_skip > 0:
            # clip skip: take an earlier hidden layer and re-apply the final layer norm
            encoder_hidden_states = text_encoder.text_model.final_layer_norm(
                encoder_hidden_states.hidden_states[-args.clip_skip])
        else:
            encoder_hidden_states = encoder_hidden_states.last_hidden_state

        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
        del noise, latents, cuda_caption

        with autocast(enabled=args.amp):
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return model_pred, target
try:
# # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
# pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
@ -809,41 +829,7 @@ def main(args):
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()
with torch.no_grad():
with autocast(enabled=args.amp):
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
latents = vae.encode(pixel_values, return_dict=False)
del pixel_values
latents = latents[0].sample() * 0.18215
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
cuda_caption = batch["tokens"].to(text_encoder.device)
#with autocast(enabled=args.amp):
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
if args.clip_skip > 0:
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states.hidden_states[-args.clip_skip])
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type in ["v_prediction", "v-prediction"]:
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
del noise, latents, cuda_caption
with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"])
#del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
@ -952,6 +938,10 @@ def main(args):
loss_local = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
# validate
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
gc.collect()
# end of epoch

73
utils/isolate_rng.py Normal file
View File

@ -0,0 +1,73 @@
# copy/pasted from pytorch lightning
# https://github.com/Lightning-AI/lightning/blob/0d52f4577310b5a1624bed4d23d49e37fb05af9e/src/lightning_fabric/utilities/seed.py
# and
# https://github.com/Lightning-AI/lightning/blob/98f7696d1681974d34fad59c03b4b58d9524ed13/src/pytorch_lightning/utilities/seed.py
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Generator, Dict, Any
import torch
import numpy as np
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
states = {
"torch": torch.get_rng_state(),
"numpy": np.random.get_state(),
"python": python_get_rng_state(),
}
if include_cuda:
states["torch.cuda"] = torch.cuda.get_rng_state_all()
return states
def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
"""Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current
process."""
torch.set_rng_state(rng_state_dict["torch"])
# torch.cuda rng_state is only included since v1.8.
if "torch.cuda" in rng_state_dict:
torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
np.random.set_state(rng_state_dict["numpy"])
version, state, gauss = rng_state_dict["python"]
python_set_rng_state((version, tuple(state), gauss))
@contextmanager
def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
    """A context manager that resets the global random state on exit to what it was before entering.

    It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.

    Args:
        include_cuda: Whether to allow this function to also control the `torch.cuda` random number generator.
            Set this to ``False`` when using the function in a forked process where CUDA re-initialization is
            prohibited.

    Example:
        >>> import torch
        >>> torch.manual_seed(1)  # doctest: +ELLIPSIS
        <torch._C.Generator object at ...>
        >>> with isolate_rng():
        ...     [torch.rand(1) for _ in range(3)]
        [tensor([0.7576]), tensor([0.2793]), tensor([0.4031])]
        >>> torch.rand(1)
        tensor([0.7576])
    """
    states = _collect_rng_states(include_cuda)
    try:
        yield
    finally:
        # fix: restore even if the body raised, so an exception inside the
        # block cannot leak mutated RNG state to the caller
        _set_rng_states(states)

65
utils/split_dataset.py Normal file
View File

@ -0,0 +1,65 @@
import argparse
import math
import os.path
import random
import shutil
from typing import Optional
from tqdm.auto import tqdm
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']


def gather_captioned_images(root_dir: str) -> list[tuple[str, Optional[str]]]:
    """
    Recursively collect (image_path, caption_path) pairs under `root_dir`.

    `caption_path` is the sibling `.txt` file with the same stem as the
    image, or None if no such file exists. Extension matching is
    case-insensitive against IMAGE_EXTENSIONS.

    Fixes vs. original:
      * paths were built with `os.path.join(directory+filename)` — string
        concatenation with no separator, producing broken paths; now joined
        correctly as `os.path.join(directory, filename)`.
      * the function was declared to return a list but was a generator, so
        the caller crashed on `len(images)`; it now returns a real list.
    """
    pairs: list[tuple[str, Optional[str]]] = []
    for directory, _, filenames in os.walk(root_dir):
        image_filenames = [f for f in filenames if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS]
        for image_filename in image_filenames:
            caption_filename = os.path.splitext(image_filename)[0] + '.txt'
            image_path = os.path.join(directory, image_filename)
            caption_path = os.path.join(directory, caption_filename)
            pairs.append((image_path, caption_path if os.path.exists(caption_path) else None))
    return pairs
def copy_captioned_image(image_caption_pair: tuple[str, Optional[str]], source_root: str, target_root: str):
    """
    Copy one image (and its caption file, when present) from under
    `source_root` to `target_root`, preserving the relative folder layout.

    `image_caption_pair` is (image_path, caption_path-or-None) as produced
    by `gather_captioned_images`.
    """
    image_path, caption_path = image_caption_pair

    # mirror the image's folder (relative to source_root) under target_root
    rel_dir = os.path.dirname(os.path.relpath(image_path, source_root))
    dest_dir = os.path.join(target_root, rel_dir)
    os.makedirs(dest_dir, exist_ok=True)

    # copy2 preserves file metadata (timestamps etc.)
    for src in (image_path, caption_path):
        if src is not None:
            shutil.copy2(src, os.path.join(dest_dir, os.path.basename(src)))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('source_root', type=str, help='Source root folder containing images')
    parser.add_argument('--train_output_folder', type=str, required=True, help="Output folder for the 'train' dataset")
    parser.add_argument('--val_output_folder', type=str, required=True, help="Output folder for the 'val' dataset")
    parser.add_argument('--split_proportion', type=float, required=True, help="Proportion of images to use for 'val' (a number between 0 and 1)")
    parser.add_argument('--seed', type=int, required=False, default=555, help='Random seed for shuffling')
    args = parser.parse_args()

    # fix: materialize the result so len(), shuffle() and slicing below work
    # even if gather_captioned_images yields lazily
    images = list(gather_captioned_images(args.source_root))
    print(f"Found {len(images)} captioned images in {args.source_root}")
    val_split_count = math.ceil(len(images) * args.split_proportion)
    if val_split_count == 0:
        raise ValueError(f"No images in validation split with source count {len(images)} and split proportion {args.split_proportion}")

    random.seed(args.seed)
    random.shuffle(images)
    val_split = images[0:val_split_count]
    train_split = images[val_split_count:]
    print(f"Split to 'train' set with {len(train_split)} images and 'val' set with {len(val_split)}")

    print(f"Copying 'val' set to {args.val_output_folder}...")
    for v in tqdm(val_split):
        copy_captioned_image(v, args.source_root, args.val_output_folder)
    print(f"Copying 'train' set to {args.train_output_folder}...")
    for v in tqdm(train_split):
        copy_captioned_image(v, args.source_root, args.train_output_folder)
    print("Done.")

20
validation_default.json Normal file
View File

@ -0,0 +1,20 @@
{
"documentation": {
"validate_training": "If true, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.",
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
"val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
"stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"every_n_epochs": "How often to run validation (1=every epoch).",
"seed": "The seed to use when running validation and stabilization passes."
},
"validate_training": true,
"val_split_mode": "automatic",
"val_data_root": null,
"val_split_proportion": 0.15,
"stabilize_training_loss": false,
"stabilize_split_proportion": 0.15,
"every_n_epochs": 1,
"seed": 555
}