diff --git a/data/data_loader.py b/data/data_loader.py index 9d9f7c4..d77ce74 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -1,5 +1,5 @@ """ -Copyright [2022] Victor C Hall +Copyright [2022-2223] Victor C Hall Licensed under the GNU Affero General Public License; You may not use this code except in compliance with the License. @@ -39,7 +39,6 @@ class DataLoaderMultiAspect(): self.batch_size = batch_size self.prepared_train_data = image_train_items random.Random(self.seed).shuffle(self.prepared_train_data) - self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating()) self.expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data])) if self.expected_epoch_size != len(self.prepared_train_data): logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.expected_epoch_size} images.") @@ -48,8 +47,6 @@ class DataLoaderMultiAspect(): self.rating_overall_sum: float = 0.0 self.ratings_summed: list[float] = [] - self.__update_rating_sums() - def __pick_multiplied_set(self, randomizer: random.Random): """ @@ -78,7 +75,7 @@ class DataLoaderMultiAspect(): return picked_images - def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]: + def get_shuffled_image_buckets(self) -> list[ImageTrainItem]: """ Returns the current list of `ImageTrainItem` in randomized order, sorted into buckets with same sized images. @@ -94,10 +91,7 @@ class DataLoaderMultiAspect(): self.seed += 1 randomizer = random.Random(self.seed) - if dropout_fraction < 1.0: - picked_images = self.__pick_random_subset(dropout_fraction, randomizer) - else: - picked_images = self.__pick_multiplied_set(randomizer) + picked_images = self.__pick_multiplied_set(randomizer) randomizer.shuffle(picked_images) @@ -131,47 +125,3 @@ class DataLoaderMultiAspect(): items.extend(buckets[bucket]) return items - - def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]: - """ - Picks a random subset of all images - - The size of the subset is limited by dropout_faction - - The chance of an image to be picked is influenced by its rating. Double that rating -> double the chance - :param dropout_fraction: must be between 0.0 and 1.0 - :param picker: seeded random picker - :return: list of picked ImageTrainItem - """ - - prepared_train_data = self.prepared_train_data.copy() - ratings_summed = self.ratings_summed.copy() - rating_overall_sum = self.rating_overall_sum - - num_images = len(prepared_train_data) - num_images_to_pick = math.ceil(num_images * dropout_fraction) - num_images_to_pick = max(min(num_images_to_pick, num_images), 0) - - # logging.info(f"Picking {num_images_to_pick} images out of the {num_images} in the dataset for drop_fraction {dropout_fraction}") - - picked_images: list[ImageTrainItem] = [] - while num_images_to_pick > len(picked_images): - # find random sample in dataset - point = picker.uniform(0.0, rating_overall_sum) - pos = min(bisect.bisect_left(ratings_summed, point), len(prepared_train_data) -1 ) - - # pick random sample - picked_image = prepared_train_data[pos] - picked_images.append(picked_image) - - # kick picked item out of data set to not pick it again - rating_overall_sum = max(rating_overall_sum - picked_image.caption.rating(), 0.0) - ratings_summed.pop(pos) - prepared_train_data.pop(pos) - - return picked_images - - def __update_rating_sums(self): - self.rating_overall_sum: float = 0.0 - self.ratings_summed: list[float] = [] - for item in self.prepared_train_data: - self.rating_overall_sum += item.caption.rating() - self.ratings_summed.append(self.rating_overall_sum) \ No newline at end of file diff --git a/data/dataset.py b/data/dataset.py index a0c761a..4538a10 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -1,16 +1,21 @@ +import cProfile +from contextlib import nullcontext import os import logging +import time import yaml import json -from functools import total_ordering -from attrs import define, field, Factory +from functools import partial +from attrs import define, field from data.image_train_item import ImageCaption, ImageTrainItem from utils.fs_helpers import * from typing import Iterable from tqdm import tqdm +from multiprocessing import Pool, Lock + DEFAULT_MAX_CAPTION_LENGTH = 2048 def overlay(overlay, base): @@ -163,12 +168,14 @@ class Dataset: cfgs.append(ImageConfig.from_file(fileset['local.yml'])) return ImageConfig.fold(cfgs) - def __sidecar_cfg(imagepath, fileset): + def __sidecar_cfg(imagepath, fileset, lock): cfgs = [] for cfgext in ['.txt', '.caption', '.yml', '.yaml']: cfgfile = barename(imagepath) + cfgext if cfgfile in fileset: - cfgs.append(ImageConfig.from_file(fileset[cfgfile])) + cfg = ImageConfig.from_file(fileset[cfgfile]) + with lock: + cfgs.append(cfg) return ImageConfig.fold(cfgs) # Use file name for caption only as a last resort @@ -179,22 +186,52 @@ class Dataset: cap_cfg = ImageConfig.from_caption_text(barename(file).split("_")[0]) return cfg.merge(cap_cfg) + @classmethod + def scan_one(cls, img, image_configs, fileset, global_cfg, local_cfg, lock): + img_cfg = Dataset.__sidecar_cfg(img, fileset, lock) + resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg]) + with lock: + image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img) + + @classmethod + def scan_one_full(cls, img, image_configs, fileset, global_cfg, local_cfg, lock): + Dataset.scan_one(img, image_configs, fileset, global_cfg, local_cfg, lock) + img_cfg = Dataset.__sidecar_cfg(img, fileset, lock) + resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg]) + image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img) + #print(f"{image_configs[img].main_prompts} {image_configs[img].tags} {image_configs[img].rating}") + + @classmethod def from_path(cls, data_root): # Create a visitor that maintains global config stack # and accumulates image configs as it traverses dataset + image_configs = {} def process_dir(files, parent_globals): + #pool = Pool(int(os.cpu_count()/2)) + lock = Lock() + fileset = {os.path.basename(f): f for f in files} global_cfg = parent_globals.merge(Dataset.__global_cfg(fileset)) local_cfg = Dataset.__local_cfg(fileset) for img in filter(is_image, files): - img_cfg = Dataset.__sidecar_cfg(img, fileset) - resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg]) - image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img) + #pool.apply_async(Dataset.scan_one_full, args=(img, image_configs, fileset, global_cfg, local_cfg, lock)) + Dataset.scan_one_full(img, image_configs, fileset, global_cfg, local_cfg, lock) + #Dataset.scan_one(img, image_configs, fileset, global_cfg, local_cfg, lock) + #pool.close() + #pool.join() + # img_cfg = Dataset.__sidecar_cfg(img, fileset) + # resolved_cfg = ImageConfig.fold([global_cfg, local_cfg, img_cfg]) + # image_configs[img] = Dataset.__ensure_caption(resolved_cfg, img) + return global_cfg + time_start = time.time() walk_and_visit(data_root, process_dir, ImageConfig()) + time_end = time.time() + logging.info(f" ... walk_and_visit took {(time_end - time_start)/60:.2f} minutes and found {len(image_configs)} images") + return Dataset(image_configs) @classmethod @@ -212,45 +249,125 @@ class Dataset: continue image_configs[img] = cfg return Dataset(image_configs) - + + def get_one_image_train_item(self, image, aspects, profile=False) -> ImageTrainItem: + + + config = self.image_configs[image] + + tags = [] + tag_weights = [] + for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True): + tags.append(tag.value) + tag_weights.append(tag.weight) + use_weights = len(set(tag_weights)) > 1 + + try: + if profile: + profiler = cProfile.Profile() + import random + random_n = f"{random.randint(0,999):03d}" + profiler.enable() + caption = ImageCaption( + main_prompt=next(iter(config.main_prompts)), + rating=config.rating or 1.0, + tags=tags, + tag_weights=tag_weights, + max_target_length=config.max_caption_length or DEFAULT_MAX_CAPTION_LENGTH, + use_weights=use_weights) + if profile: + profiler.disable() + profiler.dump_stats(f'profile{random_n}.prof') + #exit() + + item = ImageTrainItem( + image=None, + caption=caption, + aspects=aspects, + pathname=os.path.abspath(image), + flip_p=config.flip_p or 0.0, + multiplier=config.multiply or 1.0, + cond_dropout=config.cond_dropout + ) + except Exception as e: + logging.error(f" *** Error preloading image or caption for: {image}, error: {e}") + raise e + + + return item + def image_train_items(self, aspects): + print(f" * using async loader") + run_profiler = False items = [] - for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True): - config = self.image_configs[image] + process_count = int(os.cpu_count()/2) + pool = Pool(process_count) + async_results = [] - if len(config.main_prompts) > 1: - logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}") + time_start = time.time() + with tqdm(total=len(self.image_configs), desc=f"preloading {process_count}", dynamic_ncols=True) as pbar: + for image in self.image_configs: + async_result = pool.apply_async(self.get_one_image_train_item, args=(image,aspects, run_profiler), callback=lambda _: pbar.update()) + async_results.append(async_result) + pool.close() + pool.join() - if len(config.main_prompts) < 1: - logging.warning(f" *** No main_prompts for image {image}") + for async_result in async_results: + result = async_result.get() + if result is not None: + # ImageTrainItem + items.append(result) + else: + raise ValueError(" *** image_train_items(): Async load item missing") + + + + time_end = time.time() + logging.info(f" *** Preloading took {(time_end - time_start)/60:.2f} minutes and found {len(items)} images") + return items - tags = [] - tag_weights = [] - for tag in sorted(config.tags, key=lambda x: x.weight or 1.0, reverse=True): - tags.append(tag.value) - tag_weights.append(tag.weight) - use_weights = len(set(tag_weights)) > 1 + def image_train_items_newish(self, aspects): + print(f" * using async loader") + items = [] + process_count = int(os.cpu_count()/2) + pool = Pool(process_count) - try: - caption = ImageCaption( - main_prompt=next(iter(config.main_prompts)), - rating=config.rating or 1.0, - tags=tags, - tag_weights=tag_weights, - max_target_length=config.max_caption_length or DEFAULT_MAX_CAPTION_LENGTH, - use_weights=use_weights) + time_start = time.time() + with tqdm(total=len(self.image_configs), desc=f"preloading {process_count}", dynamic_ncols=True) as pbar: + async_results = [] + + # run 1000 async tasks + for image in self.image_configs: + # profile the task + #cProfile.runctx('self.get_one(image,aspects)', globals(), locals(), 'profile.prof') + async_result = pool.apply_async(self.get_one_image_train_item, args=(image,aspects), callback=lambda _: pbar.update()) + async_results.append(async_result) + pool.close() + #pool.join() + print(f" * async pool closed") - item = ImageTrainItem( - image=None, - caption=caption, - aspects=aspects, - pathname=os.path.abspath(image), - flip_p=config.flip_p or 0.0, - multiplier=config.multiply or 1.0, - cond_dropout=config.cond_dropout - ) - items.append(item) - except Exception as e: - logging.error(f" *** Error preloading image or caption for: {image}, error: {e}") - raise e - return items \ No newline at end of file + for async_result in async_results: + result = async_result.get() + if result is not None: + # ImageTrainItem + items.append(result) + print(f"{result.pathname} {result.caption.main_prompt}") + else: + raise ValueError(" *** image_train_items(): Async load item missing") + + time_end = time.time() + logging.info(f" *** Preloading took {(time_end - time_start)/60:.2f} minutes and found {len(items)} images") + return items + + def image_train_items_old(self, aspects): + print(f" * using single threaded loader") + items = [] + + time_start = time.time() + with tqdm(total=len(self.image_configs), desc="preloading", dynamic_ncols=True) as pbar: + for image in self.image_configs: + items.append(self.get_one_image_train_item(image, aspects)) + pbar.update() + time_end = time.time() + logging.info(f" *** Preloading took {(time_end - time_start)/60:.2f} minutes and found {len(items)} images") + return items diff --git a/data/every_dream.py b/data/every_dream.py index ba9ea60..c3b97b2 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -1,5 +1,5 @@ """ -Copyright [2022] Victor C Hall +Copyright [2022-2023] Victor C Hall Licensed under the GNU Affero General Public License; You may not use this code except in compliance with the License. @@ -57,11 +57,11 @@ class EveryDreamBatch(Dataset): self.retain_contrast = retain_contrast self.shuffle_tags = shuffle_tags self.seed = seed - self.rated_dataset = rated_dataset - self.rated_dataset_dropout_target = rated_dataset_dropout_target + #self.rated_dataset = rated_dataset + #self.rated_dataset_dropout_target = rated_dataset_dropout_target # First epoch always trains on all images self.image_train_items = [] - self.__update_image_train_items(1.0) + self.__update_image_train_items() self.name = name num_images = len(self.image_train_items) @@ -69,13 +69,7 @@ class EveryDreamBatch(Dataset): def shuffle(self, epoch_n: int, max_epochs: int): self.seed += 1 - - if self.rated_dataset: - dropout_fraction = (max_epochs - (epoch_n * self.rated_dataset_dropout_target)) / max_epochs - else: - dropout_fraction = 1.0 - - self.__update_image_train_items(dropout_fraction) + self.__update_image_train_items() def __len__(self): return len(self.image_train_items) @@ -140,8 +134,8 @@ class EveryDreamBatch(Dataset): return example - def __update_image_train_items(self, dropout_fraction: float): - self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction) + def __update_image_train_items(self): + self.image_train_items = self.data_loader.get_shuffled_image_buckets() def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader: dataloader = torch.utils.data.DataLoader( diff --git a/data/image_train_item.py b/data/image_train_item.py index 0346dda..741bb1f 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -56,6 +56,9 @@ class ImageCaption: if use_weights and len(tag_weights) > len(tags): self.__tag_weights = tag_weights[:len(tags)] + def __repr__(self) -> str: + return f"ImageCaption({self.__main_prompt}, {self.__rating}, {self.__tags}, {self.__tag_weights}, {self.__max_target_length}, {self.__use_weights})" + def rating(self) -> float: return self.__rating @@ -143,7 +146,6 @@ class ImageTrainItem: else: self.image = image self.image_size = image.size - self.target_size = None self.is_undersized = False self.error = None @@ -245,7 +247,7 @@ class ImageTrainItem: self.target_wh = None try: with PIL.Image.open(self.pathname) as image: - image = self._try_transpose(image, print_error=True).convert('RGB') + image = self._try_transpose(image, print_error=True) width, height = image.size image_aspect = width / height target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) diff --git a/train.py b/train.py index b0a6ca8..96ad494 100644 --- a/train.py +++ b/train.py @@ -241,8 +241,8 @@ def setup_args(args): args.clip_skip = max(min(4, args.clip_skip), 0) - if args.useadam8bit: - logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}") + #if args.useadam8bit: + # logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}") if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None: logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}") @@ -932,7 +932,7 @@ def main(args): if validator: validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target) - gc.collect() + #gc.collect() # end of epoch # end of training @@ -1011,12 +1011,12 @@ if __name__ == "__main__": argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)") argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random") argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets") - argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead") + #argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead") argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY") argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.") argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)") - argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs") - argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)") + #argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs") + #argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)") argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15") # load CLI args to overwrite existing config args diff --git a/utils/fs_helpers.py b/utils/fs_helpers.py index 775ca0f..df016ae 100644 --- a/utils/fs_helpers.py +++ b/utils/fs_helpers.py @@ -25,7 +25,7 @@ def read_float(file): try: return float(read_text(file)) except Exception as e: - logging.warning(f" *** Could not parse '{data}' to float in file {file}: {e}") + logging.warning(f" *** Could not parse number to float in file {file}: {e}") import os @@ -48,4 +48,4 @@ def walk_and_visit(path, visit_fn, context=None): subcontext = visit_fn(files, context) for subdir in dirs: - walk_and_visit(subdir, visit_fn, subcontext) \ No newline at end of file + walk_and_visit(subdir, visit_fn, subcontext)