data loader tweaks
This commit is contained in:
parent
bc7b95a375
commit
f90d8e5b53
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Copyright [2022] Victor C Hall
|
||||
Copyright [2022-2023] Victor C Hall
|
||||
|
||||
Licensed under the GNU Affero General Public License;
|
||||
You may not use this code except in compliance with the License.
|
||||
|
@ -39,7 +39,6 @@ class DataLoaderMultiAspect():
|
|||
self.batch_size = batch_size
|
||||
self.prepared_train_data = image_train_items
|
||||
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
||||
self.expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
|
||||
if self.expected_epoch_size != len(self.prepared_train_data):
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.expected_epoch_size} images.")
|
||||
|
@ -48,8 +47,6 @@ class DataLoaderMultiAspect():
|
|||
|
||||
self.rating_overall_sum: float = 0.0
|
||||
self.ratings_summed: list[float] = []
|
||||
self.__update_rating_sums()
|
||||
|
||||
|
||||
def __pick_multiplied_set(self, randomizer: random.Random):
|
||||
"""
|
||||
|
@ -78,7 +75,7 @@ class DataLoaderMultiAspect():
|
|||
|
||||
return picked_images
|
||||
|
||||
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
|
||||
def get_shuffled_image_buckets(self) -> list[ImageTrainItem]:
|
||||
"""
|
||||
Returns the current list of `ImageTrainItem` in randomized order,
|
||||
sorted into buckets with same sized images.
|
||||
|
@ -94,10 +91,7 @@ class DataLoaderMultiAspect():
|
|||
self.seed += 1
|
||||
randomizer = random.Random(self.seed)
|
||||
|
||||
if dropout_fraction < 1.0:
|
||||
picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
|
||||
else:
|
||||
picked_images = self.__pick_multiplied_set(randomizer)
|
||||
picked_images = self.__pick_multiplied_set(randomizer)
|
||||
|
||||
randomizer.shuffle(picked_images)
|
||||
|
||||
|
@ -131,47 +125,3 @@ class DataLoaderMultiAspect():
|
|||
items.extend(buckets[bucket])
|
||||
|
||||
return items
|
||||
|
||||
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> "list[ImageTrainItem]":
    """
    Picks a random subset of all images, without replacement
    - The size of the subset is limited by dropout_fraction
    - The chance of an image to be picked is influenced by its rating. Double that rating -> double the chance
    :param dropout_fraction: must be between 0.0 and 1.0
    :param picker: seeded random picker
    :return: list of picked ImageTrainItem
    """
    from itertools import accumulate

    # Work on copies so self.prepared_train_data is left untouched.
    remaining = list(self.prepared_train_data)
    weights = [item.caption.rating() for item in remaining]

    num_images = len(remaining)
    num_images_to_pick = math.ceil(num_images * dropout_fraction)
    num_images_to_pick = max(min(num_images_to_pick, num_images), 0)

    picked_images: list[ImageTrainItem] = []
    while len(picked_images) < num_images_to_pick:
        # Rebuild the running totals from the items that are still available.
        # Popping an entry out of a precomputed cumulative list leaves every
        # later entry inflated by the removed rating, which skews the odds of
        # all items that follow it.
        cumulative = list(accumulate(weights))

        # find random sample in dataset
        point = picker.uniform(0.0, cumulative[-1])
        pos = min(bisect.bisect_left(cumulative, point), len(remaining) - 1)

        # pick the sample and kick it out of the pool to not pick it again
        picked_images.append(remaining.pop(pos))
        weights.pop(pos)

    return picked_images
|
||||
|
||||
def __update_rating_sums(self):
    """
    Rebuilds the running rating totals for self.prepared_train_data.

    After the call, self.ratings_summed[i] holds the sum of the ratings of
    items 0..i (in list order) and self.rating_overall_sum holds the sum over
    all items (0.0 for an empty list).
    """
    self.rating_overall_sum: float = 0.0
    self.ratings_summed: list[float] = []
    running_totals = self.ratings_summed
    for rating in (entry.caption.rating() for entry in self.prepared_train_data):
        self.rating_overall_sum += rating
        running_totals.append(self.rating_overall_sum)
|
201
data/dataset.py
201
data/dataset.py
|
@ -1,16 +1,21 @@
|
|||
import cProfile
|
||||
from contextlib import nullcontext
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import yaml
|
||||
import json
|
||||
|
||||
from functools import total_ordering
|
||||
from attrs import define, field, Factory
|
||||
from functools import partial
|
||||
from attrs import define, field
|
||||
from data.image_train_item import ImageCaption, ImageTrainItem
|
||||
from utils.fs_helpers import *
|
||||
from typing import Iterable
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from multiprocessing import Pool, Lock
|
||||
|
||||
DEFAULT_MAX_CAPTION_LENGTH = 2048
|
||||
|
||||
def overlay(overlay, base):
|
||||
|
@ -163,12 +168,14 @@ class Dataset:
|
|||
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
|
||||
return ImageConfig.fold(cfgs)
|
||||
|
||||
def __sidecar_cfg(imagepath, fileset, lock=None):
    """
    Folds the sidecar config files found next to one image into one config.

    For each of the extensions .txt, .caption, .yml and .yaml (in that order)
    the name `barename(imagepath) + ext` is looked up in `fileset`; every hit
    is loaded with ImageConfig.from_file and the results are folded together.

    :param imagepath: path of the image whose sidecars are wanted
    :param fileset: dict of file name -> path for the image's directory
    :param lock: accepted for backward compatibility and ignored; the list
        assembled here is local to this call, so there is nothing to guard
    :return: result of ImageConfig.fold over the sidecar configs found
    """
    # The bare name does not depend on the extension, so compute it once.
    stem = barename(imagepath)
    candidates = (stem + cfgext for cfgext in ['.txt', '.caption', '.yml', '.yaml'])
    cfgs = [ImageConfig.from_file(fileset[cfgfile]) for cfgfile in candidates if cfgfile in fileset]
    return ImageConfig.fold(cfgs)
|
||||
|
||||
# Use file name for caption only as a last resort
|
||||
|
@ -179,22 +186,52 @@ class Dataset:
|
|||
cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0])
|
||||
return cfg.merge(cap_cfg)
|
||||
|
||||
@classmethod
def scan_one(cls, img, image_configs, fileset, global_cfg, local_cfg, lock):
    """
    Resolves the config of a single image and records it in `image_configs`.

    The image's sidecar config is folded on top of the global and local
    configs (in that order), a caption is ensured for the result, and the
    result is stored under `img` while `lock` is held.

    :param img: path of the image being scanned
    :param image_configs: dict receiving the resolved config, keyed by `img`
    :param fileset: dict of file name -> path for the image's directory
    :param global_cfg: config applied first when folding
    :param local_cfg: config applied second when folding
    :param lock: context manager held while `image_configs` is written
    """
    sidecar = Dataset.__sidecar_cfg(img, fileset, lock)
    layers = [global_cfg, local_cfg, sidecar]
    merged = ImageConfig.fold(layers)
    with lock:
        image_configs[img] = Dataset.__ensure_caption(merged, img)
|
||||
|
||||
@classmethod
def scan_one_full(cls, img, image_configs, fileset, global_cfg, local_cfg, lock):
    """
    Resolves the config of a single image and records it in `image_configs`.

    Kept as a separate entry point for existing callers; it simply delegates
    to scan_one. Repeating the sidecar read, fold and store here after the
    delegation would parse every sidecar file twice and write the same entry
    a second time without holding the lock.

    :param img: path of the image being scanned
    :param image_configs: dict receiving the resolved config, keyed by `img`
    :param fileset: dict of file name -> path for the image's directory
    :param global_cfg: config applied first when folding
    :param local_cfg: config applied second when folding
    :param lock: context manager held while `image_configs` is written
    """
    Dataset.scan_one(img, image_configs, fileset, global_cfg, local_cfg, lock)
|
||||
|
||||
|
||||
@classmethod
def from_path(cls, data_root):
    """
    Builds a Dataset by walking `data_root` and resolving a config per image.

    walk_and_visit calls the visitor once per directory with that directory's
    files and the config returned for its parent; the visitor returns the
    merged global config so that it is handed on to the sub-directories.

    :param data_root: root directory of the dataset
    :return: Dataset holding one resolved config per image found
    """
    image_configs = {}
    # One lock for the whole scan; scanning is done serially, so allocating
    # a fresh multiprocessing Lock for every directory buys nothing.
    lock = Lock()

    def process_dir(files, parent_globals):
        fileset = {os.path.basename(f): f for f in files}
        global_cfg = parent_globals.merge(Dataset.__global_cfg(fileset))
        local_cfg = Dataset.__local_cfg(fileset)
        for img in filter(is_image, files):
            Dataset.scan_one_full(img, image_configs, fileset, global_cfg, local_cfg, lock)
        return global_cfg

    time_start = time.time()
    walk_and_visit(data_root, process_dir, ImageConfig())
    time_end = time.time()
    logging.info(f" ... walk_and_visit took {(time_end - time_start)/60:.2f} minutes and found {len(image_configs)} images")

    return Dataset(image_configs)
|
||||
|
||||
@classmethod
|
||||
|
@ -212,45 +249,125 @@ class Dataset:
|
|||
continue
|
||||
image_configs[img] = cfg
|
||||
return Dataset(image_configs)
|
||||
|
||||
|
||||
def get_one_image_train_item(self, image, aspects, profile=False) -> ImageTrainItem:
    """
    Builds the ImageTrainItem for one image from its resolved config.

    :param image: key into self.image_configs; also used (made absolute) as
        the item's pathname
    :param aspects: passed through to ImageTrainItem
    :param profile: when True, the ImageCaption construction is profiled with
        cProfile and the stats are dumped to 'profileNNN.prof' (NNN random)
    :return: the constructed ImageTrainItem (created with image=None)
    :raises Exception: any error raised while building the caption or item is
        logged and re-raised unchanged
    """
    config = self.image_configs[image]

    # Only the first main prompt is used below, so make caption problems
    # visible instead of letting them pass silently.
    if len(config.main_prompts) > 1:
        logging.warning(f" *** Found multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
    if len(config.main_prompts) < 1:
        logging.warning(f" *** No main_prompts for image {image}")

    # Heaviest tags first; a tag without a weight sorts as if weighted 1.0.
    ordered_tags = sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True)
    tags = [tag.value for tag in ordered_tags]
    tag_weights = [tag.weight for tag in ordered_tags]
    # Weights only matter when the tags do not all share the same one.
    use_weights = len(set(tag_weights)) > 1

    try:
        profiler = None
        if profile:
            import random
            profiler = cProfile.Profile()
            random_n = f"{random.randint(0,999):03d}"
            profiler.enable()

        caption = ImageCaption(
            main_prompt=next(iter(config.main_prompts)),
            rating=config.rating or 1.0,
            tags=tags,
            tag_weights=tag_weights,
            max_target_length=config.max_caption_length or DEFAULT_MAX_CAPTION_LENGTH,
            use_weights=use_weights)

        if profiler is not None:
            profiler.disable()
            profiler.dump_stats(f'profile{random_n}.prof')

        item = ImageTrainItem(
            image=None,
            caption=caption,
            aspects=aspects,
            pathname=os.path.abspath(image),
            flip_p=config.flip_p or 0.0,
            multiplier=config.multiply or 1.0,
            cond_dropout=config.cond_dropout
        )
    except Exception as e:
        logging.error(f" *** Error preloading image or caption for: {image}, error: {e}")
        # Bare raise keeps the original traceback intact.
        raise

    return item
|
||||
|
||||
def image_train_items(self, aspects):
    """
    Builds one ImageTrainItem per entry of self.image_configs using a
    multiprocessing pool.

    :param aspects: passed through to get_one_image_train_item
    :return: list of ImageTrainItem, in the iteration order of
        self.image_configs
    :raises ValueError: if a worker returned None for an image
    """
    print(f" * using async loader")
    run_profiler = False
    items = []

    # Half of the CPUs, but never fewer than one worker: os.cpu_count() may
    # return None, and Pool(0) raises ValueError on a single-core machine.
    process_count = max(1, (os.cpu_count() or 1) // 2)

    # The context manager tears the pool down even if submitting or
    # collecting fails part-way.
    with Pool(process_count) as pool:
        async_results = []

        time_start = time.time()
        with tqdm(total=len(self.image_configs), desc=f"preloading {process_count}", dynamic_ncols=True) as pbar:
            for image in self.image_configs:
                async_result = pool.apply_async(self.get_one_image_train_item, args=(image,aspects, run_profiler), callback=lambda _: pbar.update())
                async_results.append(async_result)
            pool.close()
            pool.join()

        for async_result in async_results:
            # get() re-raises whatever the worker raised for this image.
            result = async_result.get()
            if result is not None:
                # ImageTrainItem
                items.append(result)
            else:
                raise ValueError(" *** image_train_items(): Async load item missing")

    time_end = time.time()
    logging.info(f" *** Preloading took {(time_end - time_start)/60:.2f} minutes and found {len(items)} images")
    return items
|
||||
|
||||
tags = []
|
||||
tag_weights = []
|
||||
for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True):
|
||||
tags.append(tag.value)
|
||||
tag_weights.append(tag.weight)
|
||||
use_weights = len(set(tag_weights)) > 1
|
||||
def image_train_items_newish(self, aspects):
    """
    Variant of the pooled loader that does not join the pool before
    collecting; it waits on each result in turn and prints every loaded item.

    :param aspects: passed through to get_one_image_train_item
    :return: list of ImageTrainItem, in the iteration order of
        self.image_configs
    :raises ValueError: if a worker returned None for an image
    """
    print(f" * using async loader")
    items = []

    # Half of the CPUs, but never fewer than one worker: os.cpu_count() may
    # return None, and Pool(0) raises ValueError on a single-core machine.
    process_count = max(1, (os.cpu_count() or 1) // 2)

    # The context manager tears the pool down even if submitting or
    # collecting fails part-way.
    with Pool(process_count) as pool:
        time_start = time.time()
        with tqdm(total=len(self.image_configs), desc=f"preloading {process_count}", dynamic_ncols=True) as pbar:
            async_results = []

            for image in self.image_configs:
                async_result = pool.apply_async(self.get_one_image_train_item, args=(image,aspects), callback=lambda _: pbar.update())
                async_results.append(async_result)
            pool.close()
            print(f" * async pool closed")

        for async_result in async_results:
            # Blocks until this task is done; re-raises the worker's error.
            result = async_result.get()
            if result is not None:
                # ImageTrainItem
                items.append(result)
                print(f"{result.pathname} {result.caption.main_prompt}")
            else:
                raise ValueError(" *** image_train_items(): Async load item missing")

    time_end = time.time()
    logging.info(f" *** Preloading took {(time_end - time_start)/60:.2f} minutes and found {len(items)} images")
    return items
|
||||
|
||||
def image_train_items_old(self, aspects):
    """
    Builds one ImageTrainItem per entry of self.image_configs, one after the
    other in the calling process, and logs how long that took.

    :param aspects: passed through to get_one_image_train_item
    :return: list of ImageTrainItem, in the iteration order of
        self.image_configs
    """
    print(f" * using single threaded loader")
    time_start = time.time()
    items = []

    total = len(self.image_configs)
    with tqdm(total=total, desc="preloading", dynamic_ncols=True) as progress:
        for image_path in self.image_configs:
            train_item = self.get_one_image_train_item(image_path, aspects)
            items.append(train_item)
            progress.update()

    time_end = time.time()
    logging.info(f" *** Preloading took {(time_end - time_start)/60:.2f} minutes and found {len(items)} images")
    return items
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Copyright [2022] Victor C Hall
|
||||
Copyright [2022-2023] Victor C Hall
|
||||
|
||||
Licensed under the GNU Affero General Public License;
|
||||
You may not use this code except in compliance with the License.
|
||||
|
@ -57,11 +57,11 @@ class EveryDreamBatch(Dataset):
|
|||
self.retain_contrast = retain_contrast
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.seed = seed
|
||||
self.rated_dataset = rated_dataset
|
||||
self.rated_dataset_dropout_target = rated_dataset_dropout_target
|
||||
#self.rated_dataset = rated_dataset
|
||||
#self.rated_dataset_dropout_target = rated_dataset_dropout_target
|
||||
# First epoch always trains on all images
|
||||
self.image_train_items = []
|
||||
self.__update_image_train_items(1.0)
|
||||
self.__update_image_train_items()
|
||||
self.name = name
|
||||
|
||||
num_images = len(self.image_train_items)
|
||||
|
@ -69,13 +69,7 @@ class EveryDreamBatch(Dataset):
|
|||
|
||||
def shuffle(self, epoch_n: int, max_epochs: int):
    """
    Advances the seed by one and refreshes self.image_train_items.

    :param epoch_n: not used by this implementation
    :param max_epochs: not used by this implementation
    """
    self.seed = self.seed + 1
    self.__update_image_train_items()
|
||||
|
||||
def __len__(self):
    """Returns the number of entries currently in self.image_train_items."""
    current_items = self.image_train_items
    return len(current_items)
|
||||
|
@ -140,8 +134,8 @@ class EveryDreamBatch(Dataset):
|
|||
|
||||
return example
|
||||
|
||||
def __update_image_train_items(self):
    """Replaces self.image_train_items with a fresh list from the data loader's get_shuffled_image_buckets()."""
    reshuffled = self.data_loader.get_shuffled_image_buckets()
    self.image_train_items = reshuffled
|
||||
|
||||
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
|
|
@ -56,6 +56,9 @@ class ImageCaption:
|
|||
if use_weights and len(tag_weights) > len(tags):
|
||||
self.__tag_weights = tag_weights[:len(tags)]
|
||||
|
||||
def __repr__(self) -> str:
    """Returns 'ImageCaption(...)' listing main prompt, rating, tags, tag weights, max target length and use_weights, in that order."""
    fields = (
        self.__main_prompt,
        self.__rating,
        self.__tags,
        self.__tag_weights,
        self.__max_target_length,
        self.__use_weights,
    )
    # format(value) is what an f-string placeholder without a spec applies.
    return "ImageCaption(" + ", ".join(format(value) for value in fields) + ")"
|
||||
|
||||
def rating(self) -> float:
    """Returns the rating stored on this caption."""
    stored_rating = self.__rating
    return stored_rating
|
||||
|
||||
|
@ -143,7 +146,6 @@ class ImageTrainItem:
|
|||
else:
|
||||
self.image = image
|
||||
self.image_size = image.size
|
||||
self.target_size = None
|
||||
|
||||
self.is_undersized = False
|
||||
self.error = None
|
||||
|
@ -245,7 +247,7 @@ class ImageTrainItem:
|
|||
self.target_wh = None
|
||||
try:
|
||||
with PIL.Image.open(self.pathname) as image:
|
||||
image = self._try_transpose(image, print_error=True).convert('RGB')
|
||||
image = self._try_transpose(image, print_error=True)
|
||||
width, height = image.size
|
||||
image_aspect = width / height
|
||||
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
|
||||
|
|
12
train.py
12
train.py
|
@ -241,8 +241,8 @@ def setup_args(args):
|
|||
|
||||
args.clip_skip = max(min(4, args.clip_skip), 0)
|
||||
|
||||
if args.useadam8bit:
|
||||
logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}")
|
||||
#if args.useadam8bit:
|
||||
# logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}")
|
||||
|
||||
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
|
||||
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
|
||||
|
@ -932,7 +932,7 @@ def main(args):
|
|||
if validator:
|
||||
validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target)
|
||||
|
||||
gc.collect()
|
||||
#gc.collect()
|
||||
# end of epoch
|
||||
|
||||
# end of training
|
||||
|
@ -1011,12 +1011,12 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
|
||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
|
||||
#argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
|
||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
|
||||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
||||
#argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||
#argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
||||
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
|
||||
|
||||
# load CLI args to overwrite existing config args
|
||||
|
|
|
@ -25,7 +25,7 @@ def read_float(file):
|
|||
try:
|
||||
return float(read_text(file))
|
||||
except Exception as e:
|
||||
logging.warning(f" *** Could not parse '{data}' to float in file {file}: {e}")
|
||||
logging.warning(f" *** Could not parse number to float in file {file}: {e}")
|
||||
|
||||
import os
|
||||
|
||||
|
@ -48,4 +48,4 @@ def walk_and_visit(path, visit_fn, context=None):
|
|||
subcontext = visit_fn(files, context)
|
||||
|
||||
for subdir in dirs:
|
||||
walk_and_visit(subdir, visit_fn, subcontext)
|
||||
walk_and_visit(subdir, visit_fn, subcontext)
|
||||
|
|
Loading…
Reference in New Issue