data loader tweaks

This commit is contained in:
Victor Hall 2023-04-18 22:11:51 -04:00
parent bc7b95a375
commit f90d8e5b53
6 changed files with 181 additions and 118 deletions

View File

@ -1,5 +1,5 @@
"""
Copyright [2022] Victor C Hall
Copyright [2022-2023] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
@ -39,7 +39,6 @@ class DataLoaderMultiAspect():
self.batch_size = batch_size
self.prepared_train_data = image_train_items
random.Random(self.seed).shuffle(self.prepared_train_data)
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
self.expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
if self.expected_epoch_size != len(self.prepared_train_data):
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.expected_epoch_size} images.")
@ -48,8 +47,6 @@ class DataLoaderMultiAspect():
self.rating_overall_sum: float = 0.0
self.ratings_summed: list[float] = []
self.__update_rating_sums()
def __pick_multiplied_set(self, randomizer: random.Random):
"""
@ -78,7 +75,7 @@ class DataLoaderMultiAspect():
return picked_images
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
def get_shuffled_image_buckets(self) -> list[ImageTrainItem]:
"""
Returns the current list of `ImageTrainItem` in randomized order,
sorted into buckets with same sized images.
@ -94,10 +91,7 @@ class DataLoaderMultiAspect():
self.seed += 1
randomizer = random.Random(self.seed)
if dropout_fraction < 1.0:
picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
else:
picked_images = self.__pick_multiplied_set(randomizer)
picked_images = self.__pick_multiplied_set(randomizer)
randomizer.shuffle(picked_images)
@ -131,47 +125,3 @@ class DataLoaderMultiAspect():
items.extend(buckets[bucket])
return items
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
    """
    Picks a random subset of all images, without replacement
    - The size of the subset is limited by dropout_fraction
    - The chance of an image to be picked is influenced by its rating. Double that rating -> double the chance
    :param dropout_fraction: must be between 0.0 and 1.0
    :param picker: seeded random picker
    :return: list of picked ImageTrainItem
    """
    # Work on copies so the instance-level tables are left untouched.
    prepared_train_data = self.prepared_train_data.copy()
    ratings_summed = self.ratings_summed.copy()
    rating_overall_sum = self.rating_overall_sum

    num_images = len(prepared_train_data)
    num_images_to_pick = math.ceil(num_images * dropout_fraction)
    # Clamp to [0, num_images] so an out-of-range fraction cannot over- or under-run.
    num_images_to_pick = max(min(num_images_to_pick, num_images), 0)

    picked_images: list[ImageTrainItem] = []
    while num_images_to_pick > len(picked_images):
        # find random sample in dataset: a point on the cumulative rating line
        point = picker.uniform(0.0, rating_overall_sum)
        # clamp guards against float drift pushing the point past the last entry
        pos = min(bisect.bisect_left(ratings_summed, point), len(prepared_train_data) - 1)

        # pick random sample and kick it out of the data set to not pick it again
        picked_image = prepared_train_data.pop(pos)
        picked_images.append(picked_image)

        picked_rating = picked_image.caption.rating()
        rating_overall_sum = max(rating_overall_sum - picked_rating, 0.0)
        ratings_summed.pop(pos)
        # Every cumulative sum after the removed slot still contains the removed
        # rating; subtract it so the table stays consistent with
        # rating_overall_sum and the remaining images keep rating-proportional odds.
        for i in range(pos, len(ratings_summed)):
            ratings_summed[i] -= picked_rating

    return picked_images
def __update_rating_sums(self):
    """
    Rebuild the rating lookup tables from `self.prepared_train_data`.

    `self.ratings_summed[i]` becomes the running total of the caption ratings
    of items 0..i, and `self.rating_overall_sum` the total over all items.
    """
    self.rating_overall_sum: float = 0.0
    self.ratings_summed: list[float] = []
    for train_item in self.prepared_train_data:
        running_total = self.rating_overall_sum + train_item.caption.rating()
        self.rating_overall_sum = running_total
        self.ratings_summed.append(running_total)

View File

@ -1,16 +1,21 @@
import cProfile
from contextlib import nullcontext
import os
import logging
import time
import yaml
import json
from functools import total_ordering
from attrs import define, field, Factory
from functools import partial
from attrs import define, field
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *
from typing import Iterable
from tqdm import tqdm
from multiprocessing import Pool, Lock
DEFAULT_MAX_CAPTION_LENGTH = 2048
def overlay(overlay, base):
@ -163,12 +168,14 @@ class Dataset:
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
return ImageConfig.fold(cfgs)
def __sidecar_cfg(imagepath, fileset):
def __sidecar_cfg(imagepath, fileset, lock):
cfgs = []
for cfgext in ['.txt', '.caption', '.yml', '.yaml']:
cfgfile = barename(imagepath) + cfgext
if cfgfile in fileset:
cfgs.append(ImageConfig.from_file(fileset[cfgfile]))
cfg = ImageConfig.from_file(fileset[cfgfile])
with lock:
cfgs.append(cfg)
return ImageConfig.fold(cfgs)
# Use file name for caption only as a last resort
@ -179,22 +186,52 @@ class Dataset:
cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
return cfg.merge(cap_cfg)
@classmethod
def scan_one(cls, img, image_configs, fileset, global_cfg, local_cfg, lock):
    """
    Resolve the config for a single image and store it in `image_configs`.

    The image's sidecar config is folded on top of `global_cfg` and
    `local_cfg` (in that order), a caption is ensured for the result, and the
    result is written to `image_configs[img]` while holding `lock`.
    """
    sidecar_cfg = Dataset.__sidecar_cfg(img, fileset, lock)
    layers = [global_cfg, local_cfg, sidecar_cfg]
    folded_cfg = ImageConfig.fold(layers)
    with lock:
        image_configs[img] = Dataset.__ensure_caption(folded_cfg, img)
@classmethod
def scan_one_full(cls, img, image_configs, fileset, global_cfg, local_cfg, lock):
    """
    Resolve the config for a single image and store it in `image_configs`.

    Delegates to `scan_one`, which reads the sidecar config, folds it on top
    of `global_cfg` and `local_cfg`, ensures a caption and writes
    `image_configs[img]` under `lock`. Previously the same work was repeated
    a second time after the call, re-reading the sidecar files and storing
    the result again without holding the lock.
    """
    Dataset.scan_one(img, image_configs, fileset, global_cfg, local_cfg, lock)
@classmethod
def from_path(cls, data_root):
# Create a visitor that maintains global config stack
# and accumulates image configs as it traverses dataset
image_configs = {}
def process_dir(files, parent_globals):
#pool = Pool(int(os.cpu_count()/2))
lock = Lock()
fileset = {os.path.basename(f): f for f in files}
global_cfg = parent_globals.merge(Dataset.__global_cfg(fileset))
local_cfg = Dataset.__local_cfg(fileset)
for img in filter(is_image, files):
img_cfg = Dataset.__sidecar_cfg(img, fileset)
resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg])
image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img)
#pool.apply_async(Dataset.scan_one_full, args=(img, image_configs, fileset, global_cfg, local_cfg, lock))
Dataset.scan_one_full(img, image_configs, fileset, global_cfg, local_cfg, lock)
#Dataset.scan_one(img, image_configs, fileset, global_cfg, local_cfg, lock)
#pool.close()
#pool.join()
# img_cfg = Dataset.__sidecar_cfg(img, fileset)
# resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg])
# image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img)
return global_cfg
time_start = time.time()
walk_and_visit(data_root, process_dir, ImageConfig())
time_end = time.time()
logging.info(f" ... walk_and_visit took {(time_end - time_start)/60:.2f} minutes and found {len(image_configs)} images")
return Dataset(image_configs)
@classmethod
@ -212,45 +249,125 @@ class Dataset:
continue
image_configs[img] = cfg
return Dataset(image_configs)
def get_one_image_train_item(self, image, aspects, profile=False) -> ImageTrainItem:
    """
    Build the `ImageTrainItem` for one entry of `self.image_configs`.

    :param image: key into `self.image_configs`; its absolute path becomes the item's pathname
    :param aspects: passed through unchanged to `ImageTrainItem`
    :param profile: when True, caption construction runs under cProfile and the
        stats are dumped to a `profileNNN.prof` file (NNN is a random 3-digit number)
    :return: the new ImageTrainItem, created with `image=None`
    :raises Exception: any error while building the caption or item is logged and re-raised
    """
    cfg = self.image_configs[image]

    # Order tags by descending weight; a falsy weight is ranked as 1.0.
    ranked_tags = sorted(cfg.tags, key=lambda x: x.weight or 1.0, reverse=True)
    tags = [ranked.value for ranked in ranked_tags]
    tag_weights = [ranked.weight for ranked in ranked_tags]
    # Weights are only used when at least two distinct values are present.
    distinct_weights = set(tag_weights)
    use_weights = len(distinct_weights) > 1

    try:
        if profile:
            profiler = cProfile.Profile()
            import random
            random_n = f"{random.randint(0,999):03d}"
            profiler.enable()

        # Only the first main prompt is applied.
        first_prompt = next(iter(cfg.main_prompts))
        caption = ImageCaption(
            main_prompt=first_prompt,
            rating=cfg.rating or 1.0,
            tags=tags,
            tag_weights=tag_weights,
            max_target_length=cfg.max_caption_length or DEFAULT_MAX_CAPTION_LENGTH,
            use_weights=use_weights)

        if profile:
            profiler.disable()
            profiler.dump_stats(f'profile{random_n}.prof')

        item = ImageTrainItem(
            image=None,
            caption=caption,
            aspects=aspects,
            pathname=os.path.abspath(image),
            flip_p=cfg.flip_p or 0.0,
            multiplier=cfg.multiply or 1.0,
            cond_dropout=cfg.cond_dropout
        )
    except Exception as e:
        logging.error(f" *** Error preloading image or caption for: {image}, error: {e}")
        raise e
    return item
def image_train_items(self, aspects):
print(f" * using async loader")
run_profiler = False
items = []
for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True):
config = self.image_configs[image]
process_count = int(os.cpu_count()/2)
pool = Pool(process_count)
async_results = []
if len(config.main_prompts) > 1:
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
time_start = time.time()
with tqdm(total=len(self.image_configs), desc=f"preloading {process_count}", dynamic_ncols=True) as pbar:
for image in self.image_configs:
async_result = pool.apply_async(self.get_one_image_train_item, args=(image,aspects, run_profiler), callback=lambda _: pbar.update())
async_results.append(async_result)
pool.close()
pool.join()
if len(config.main_prompts) < 1:
logging.warning(f" *** No main_prompts for image {image}")
for async_result in async_results:
result = async_result.get()
if result is not None:
# ImageTrainItem
items.append(result)
else:
raise ValueError(" *** image_train_items(): Async load item missing")
time_end = time.time()
logging.info(f" *** Preloading took {(time_end - time_start)/60:.2f} minutes and found {len(items)} images")
return items
tags = []
tag_weights = []
for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True):
tags.append(tag.value)
tag_weights.append(tag.weight)
use_weights = len(set(tag_weights)) > 1
def image_train_items_newish(self, aspects):
print(f" * using async loader")
items = []
process_count = int(os.cpu_count()/2)
pool = Pool(process_count)
try:
caption = ImageCaption(
main_prompt=next(iter(config.main_prompts)),
rating=config.rating or 1.0,
tags=tags,
tag_weights=tag_weights,
max_target_length=config.max_caption_length or DEFAULT_MAX_CAPTION_LENGTH,
use_weights=use_weights)
time_start = time.time()
with tqdm(total=len(self.image_configs), desc=f"preloading {process_count}", dynamic_ncols=True) as pbar:
async_results = []
# queue one async task per image
for image in self.image_configs:
# profile the task
#cProfile.runctx('self.get_one(image,aspects)', globals(), locals(), 'profile.prof')
async_result = pool.apply_async(self.get_one_image_train_item, args=(image,aspects), callback=lambda _: pbar.update())
async_results.append(async_result)
pool.close()
#pool.join()
print(f" * async pool closed")
item = ImageTrainItem(
image=None,
caption=caption,
aspects=aspects,
pathname=os.path.abspath(image),
flip_p=config.flip_p or 0.0,
multiplier=config.multiply or 1.0,
cond_dropout=config.cond_dropout
)
items.append(item)
except Exception as e:
logging.error(f" *** Error preloading image or caption for: {image}, error: {e}")
raise e
return items
for async_result in async_results:
result = async_result.get()
if result is not None:
# ImageTrainItem
items.append(result)
print(f"{result.pathname} {result.caption.main_prompt}")
else:
raise ValueError(" *** image_train_items(): Async load item missing")
time_end = time.time()
logging.info(f" *** Preloading took {(time_end - time_start)/60:.2f} minutes and found {len(items)} images")
return items
def image_train_items_old(self, aspects):
    """
    Build the train items for every entry of `self.image_configs` in-process,
    one after another, with a progress bar.

    :param aspects: passed through to `get_one_image_train_item`
    :return: list of the items, in `self.image_configs` iteration order
    """
    print(f" * using single threaded loader")
    loaded = []
    time_start = time.time()
    progress = tqdm(total=len(self.image_configs), desc="preloading", dynamic_ncols=True)
    with progress:
        for image_path in self.image_configs:
            train_item = self.get_one_image_train_item(image_path, aspects)
            loaded.append(train_item)
            progress.update()
    time_end = time.time()
    logging.info(f" *** Preloading took {(time_end - time_start)/60:.2f} minutes and found {len(loaded)} images")
    return loaded

View File

@ -1,5 +1,5 @@
"""
Copyright [2022] Victor C Hall
Copyright [2022-2023] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
@ -57,11 +57,11 @@ class EveryDreamBatch(Dataset):
self.retain_contrast = retain_contrast
self.shuffle_tags = shuffle_tags
self.seed = seed
self.rated_dataset = rated_dataset
self.rated_dataset_dropout_target = rated_dataset_dropout_target
#self.rated_dataset = rated_dataset
#self.rated_dataset_dropout_target = rated_dataset_dropout_target
# First epoch always trains on all images
self.image_train_items = []
self.__update_image_train_items(1.0)
self.__update_image_train_items()
self.name = name
num_images = len(self.image_train_items)
@ -69,13 +69,7 @@ class EveryDreamBatch(Dataset):
def shuffle(self, epoch_n: int, max_epochs: int):
self.seed += 1
if self.rated_dataset:
dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs
else:
dropout_fraction = 1.0
self.__update_image_train_items(dropout_fraction)
self.__update_image_train_items()
def __len__(self):
return len(self.image_train_items)
@ -140,8 +134,8 @@ class EveryDreamBatch(Dataset):
return example
def __update_image_train_items(self, dropout_fraction: float):
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
def __update_image_train_items(self):
self.image_train_items = self.data_loader.get_shuffled_image_buckets()
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
dataloader = torch.utils.data.DataLoader(

View File

@ -56,6 +56,9 @@ class ImageCaption:
if use_weights and len(tag_weights) > len(tags):
self.__tag_weights = tag_weights[:len(tags)]
# Debug representation: lists the main prompt, rating, tags, tag weights,
# max target length and use_weights flag, in that order.
def __repr__(self) -> str:
return f"ImageCaption({self.__main_prompt}, {self.__rating}, {self.__tags}, {self.__tag_weights}, {self.__max_target_length}, {self.__use_weights})"
# Return the rating stored on this caption.
def rating(self) -> float:
return self.__rating
@ -143,7 +146,6 @@ class ImageTrainItem:
else:
self.image = image
self.image_size = image.size
self.target_size = None
self.is_undersized = False
self.error = None
@ -245,7 +247,7 @@ class ImageTrainItem:
self.target_wh = None
try:
with PIL.Image.open(self.pathname) as image:
image = self._try_transpose(image, print_error=True).convert('RGB')
image = self._try_transpose(image, print_error=True)
width, height = image.size
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))

View File

@ -241,8 +241,8 @@ def setup_args(args):
args.clip_skip = max(min(4, args.clip_skip), 0)
if args.useadam8bit:
logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}")
#if args.useadam8bit:
# logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}")
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
@ -932,7 +932,7 @@ def main(args):
if validator:
validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target)
gc.collect()
#gc.collect()
# end of epoch
# end of training
@ -1011,12 +1011,12 @@ if __name__ == "__main__":
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
#argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
#argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
#argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
# load CLI args to overwrite existing config args

View File

@ -25,7 +25,7 @@ def read_float(file):
try:
return float(read_text(file))
except Exception as e:
logging.warning(f" *** Could not parse '{data}' to float in file {file}: {e}")
logging.warning(f" *** Could not parse number to float in file {file}: {e}")
import os
@ -48,4 +48,4 @@ def walk_and_visit(path, visit_fn, context=None):
subcontext = visit_fn(files, context)
for subdir in dirs:
walk_and_visit(subdir, visit_fn, subcontext)
walk_and_visit(subdir, visit_fn, subcontext)