From 4e37200ddad89d13f8458281ace2b116899a300d Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 8 Feb 2023 11:28:45 +0100 Subject: [PATCH 1/3] fix multiplier issues with validation and refactor validation logic --- data/data_loader.py | 40 ++++++++++++---------------------- data/every_dream.py | 6 ----- data/every_dream_validation.py | 20 +++++++++++++---- data/image_train_item.py | 2 +- data/resolver.py | 12 ++-------- train.py | 17 ++++++++++----- 6 files changed, 44 insertions(+), 53 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index 5fe4ba6..557e9d7 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ import bisect +import logging from functools import reduce import math import copy @@ -41,40 +42,27 @@ class DataLoaderMultiAspect(): self.rating_overall_sum: float = 0.0 self.ratings_summed: list[float] = [] self.__update_rating_sums() + count_including_multipliers = sum([math.floor(max(i.multiplier, 1)) for i in self.prepared_train_data]) + if count_including_multipliers > len(self.prepared_train_data): + logging.info(f" * DLMA initialized with {len(image_train_items)} items ({count_including_multipliers} items total after applying multipliers)") + else: + logging.info(f" * DLMA initialized with {len(image_train_items)} items") + def __pick_multiplied_set(self, randomizer): """ Deals with multiply.txt whole and fractional numbers """ - #print(f"Picking multiplied set from {len(self.prepared_train_data)}") - data_copy = copy.deepcopy(self.prepared_train_data) # deep copy to avoid modifying original multiplier property - epoch_size = len(self.prepared_train_data) picked_images = [] - - # add by whole number part first and decrement multiplier in copy - for iti in data_copy: - #print(f"check for whole number {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}") - while iti.multiplier >= 1.0: + for iti in self.prepared_train_data: + multiplier = iti.multiplier + while multiplier >= 1: + picked_images.append(iti) + multiplier -= 1 + # deal with fractional remainder + if multiplier > randomizer.uniform(0, 1): picked_images.append(iti) - #print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}, , datalen: {len(picked_images)}") - iti.multiplier -= 1.0 - remaining = epoch_size - len(picked_images) - - assert remaining >= 0, "Something went wrong with the multiplier calculation" - - # add by remaining fractional numbers by random chance - while remaining > 0: - for iti in data_copy: - if randomizer.uniform(0.0, 1.0) < iti.multiplier: - #print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {remaining}, datalen: {len(data_copy)}") - picked_images.append(iti) - remaining -= 1 - iti.multiplier = 0.0 - if remaining <= 0: - break - - del data_copy return picked_images def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]: diff --git a/data/every_dream.py b/data/every_dream.py index 30f10aa..06cdacc 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -65,12 +65,6 @@ class EveryDreamBatch(Dataset): num_images = len(self.image_train_items) logging.info(f" ** Dataset '{name}': {num_images / self.batch_size:.0f} batches, num_images: {num_images}, batch_size: {self.batch_size}") - def get_random_split(self, split_proportion: float, remove_from_dataset: bool=False) -> list[ImageTrainItem]: - items = self.data_loader.get_random_split(split_proportion, remove_from_dataset) - self.__update_image_train_items(1.0) - return items - - def shuffle(self, epoch_n: int, max_epochs: int): self.seed += 1 diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index 1302cc8..4f2c083 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -1,7 +1,9 @@ +import copy import json +import logging import math import random -from typing import Callable, Any, Optional +from typing import Callable, Any, Optional, Generator from argparse import Namespace import torch @@ -29,22 +31,28 @@ def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch remaining_items = list(items_copy[split_item_count:]) return split_items, remaining_items +def disable_multiplier_and_flip(items: list[ImageTrainItem]) -> Generator[ImageTrainItem, None, None]: + for i in items: + yield ImageTrainItem(image=i.image, caption=i.caption, aspects=i.aspects, pathname=i.pathname, flip_p=0, multiplier=1) class EveryDreamValidator: def __init__(self, val_config_path: Optional[str], default_batch_size: int, + resolution: int, log_writer: SummaryWriter): self.val_dataloader = None self.train_overlapping_dataloader = None self.log_writer = log_writer + self.resolution = resolution self.config = { 'batch_size': default_batch_size, 'every_n_epochs': 1, 'seed': 555, + 'validate_training': True, 'val_split_mode': 'automatic', 'val_split_proportion': 0.15, @@ -120,21 +128,24 @@ class EveryDreamValidator: def _build_val_dataloader_if_required(self, image_train_items: list[ImageTrainItem], tokenizer)\ -> tuple[Optional[torch.utils.data.DataLoader], list[ImageTrainItem]]: - val_split_mode = self.config['val_split_mode'] + val_split_mode = self.config['val_split_mode'] if self.config['validate_training'] else None val_split_proportion = self.config['val_split_proportion'] remaining_train_items = image_train_items - if val_split_mode == 'none': + if val_split_mode is None or val_split_mode == 'none': return None, image_train_items elif val_split_mode == 'automatic': val_items, remaining_train_items = get_random_split(image_train_items, val_split_proportion, batch_size=self.batch_size) + val_items = list(disable_multiplier_and_flip(val_items)) + logging.info(f" * Removed {len(val_items)} images from the training set to use for validation") elif val_split_mode == 'manual': args = Namespace( - aspects=aspects.get_aspect_buckets(512), + aspects=aspects.get_aspect_buckets(self.resolution), flip_p=0.0, seed=self.seed, ) val_data_root = self.config['val_data_root'] val_items = resolver.resolve_root(val_data_root, args) + logging.info(f" * Loaded {len(val_items)} validation images from {val_data_root}") else: raise ValueError(f"Unrecognized validation split mode '{val_split_mode}'") val_ed_batch = self._build_ed_batch(val_items, batch_size=self.batch_size, tokenizer=tokenizer, name='val') @@ -149,6 +160,7 @@ class EveryDreamValidator: stabilize_split_proportion = self.config['stabilize_split_proportion'] stabilize_items, _ = get_random_split(image_train_items, stabilize_split_proportion, batch_size=self.batch_size) + stabilize_items = list(disable_multiplier_and_flip(stabilize_items)) stabilize_ed_batch = self._build_ed_batch(stabilize_items, batch_size=self.batch_size, tokenizer=tokenizer, name='stabilize-train') stabilize_dataloader = build_torch_dataloader(stabilize_ed_batch, batch_size=self.batch_size) diff --git a/data/image_train_item.py b/data/image_train_item.py index de72374..8e88612 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -263,7 +263,7 @@ class ImageTrainItem: self.multiplier = multiplier self.image_size = None - if image is None: + if image is None or len(image) == 0: self.image = [] else: self.image = image diff --git a/data/resolver.py b/data/resolver.py index fbcd076..e66a3b8 100644 --- a/data/resolver.py +++ b/data/resolver.py @@ -128,7 +128,7 @@ class DirectoryResolver(DataResolver): with open(multiply_txt_path, 'r') as f: val = float(f.read().strip()) multipliers[current_dir] = val - logging.info(f" * DLMA multiply.txt in {current_dir} set to {val}") + logging.info(f" - multiply.txt in '{current_dir}' set to {val}") except Exception as e: logging.warning(f" * {Fore.LIGHTYELLOW_EX}Error trying to read multiply.txt for {current_dir}: {Style.RESET_ALL}{e}") multipliers[current_dir] = 1.0 @@ -137,16 +137,8 @@ class DirectoryResolver(DataResolver): caption = ImageCaption.resolve(pathname) item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir]) - - cur_file_multiplier = multipliers[current_dir] + items.append(item) - while cur_file_multiplier >= 1.0: - items.append(item) - cur_file_multiplier -= 1 - - if cur_file_multiplier > 0: - if randomizer.random() < cur_file_multiplier: - items.append(item) return items @staticmethod diff --git a/train.py b/train.py index dcd5ed4..9ef5d32 100644 --- a/train.py +++ b/train.py @@ -57,7 +57,8 @@ from data.every_dream_validation import EveryDreamValidator from data.image_train_item import ImageTrainItem from utils.huggingface_downloader import try_download_model_from_hf from utils.convert_diff_to_ckpt import convert as converter -from utils.gpu import GPU +if torch.cuda.is_available(): + from utils.gpu import GPU import data.aspects as aspects import data.resolver as resolver @@ -326,8 +327,7 @@ def resolve_image_train_items(args: argparse.Namespace, log_folder: str) -> list # Remove erroneous items image_train_items = [item for item in resolved_items if item.error is None] - - print (f" * DLMA: {len(image_train_items)} images loaded from {len(image_paths)} files") + print (f" * Found {len(image_paths)} files in '{args.data_root}'") return image_train_items @@ -620,9 +620,13 @@ def main(args): image_train_items = resolve_image_train_items(args, log_folder) - #validator = EveryDreamValidator(args.validation_config, log_writer=log_writer, default_batch_size=args.batch_size) + validator = EveryDreamValidator(args.validation_config, + default_batch_size=args.batch_size, + resolution=args.resolution, + log_writer=log_writer, + ) # the validation dataset may need to steal some items from image_train_items - #image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer) + image_train_items = validator.prepare_validation_splits(image_train_items, tokenizer=tokenizer) data_loader = DataLoaderMultiAspect( image_train_items=image_train_items, @@ -940,7 +944,7 @@ def main(args): log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step) # validate - #validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target) + validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target) gc.collect() # end of epoch @@ -1021,6 +1025,7 @@ if __name__ == "__main__": argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets") argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!") argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY") + argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Uses defaults if omitted.") argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)") argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs") argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)") From a7b00e9ef3bf739c3f2026c02ef056d094e070af Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 8 Feb 2023 13:46:58 +0100 Subject: [PATCH 2/3] fix multiplier logic --- data/data_loader.py | 52 ++++++++++++++++++++++++++++----------------- train.py | 33 +++++++++++++++------------- 2 files changed, 51 insertions(+), 34 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index 557e9d7..28fcd00 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -39,14 +39,15 @@ class DataLoaderMultiAspect(): self.prepared_train_data = image_train_items random.Random(self.seed).shuffle(self.prepared_train_data) self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating()) + self.epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data])) + if self.epoch_size != len(self.prepared_train_data): + logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.epoch_size} images.") + else: + logging.info(f" * DLMA initialized with {len(image_train_items)} images.") + self.rating_overall_sum: float = 0.0 self.ratings_summed: list[float] = [] self.__update_rating_sums() - count_including_multipliers = sum([math.floor(max(i.multiplier, 1)) for i in self.prepared_train_data]) - if count_including_multipliers > len(self.prepared_train_data): - logging.info(f" * DLMA initialized with {len(image_train_items)} items ({count_including_multipliers} items total after applying multipliers)") - else: - logging.info(f" * DLMA initialized with {len(image_train_items)} items") def __pick_multiplied_set(self, randomizer): @@ -54,14 +55,28 @@ class DataLoaderMultiAspect(): Deals with multiply.txt whole and fractional numbers """ picked_images = [] + fractional_images = [] for iti in self.prepared_train_data: multiplier = iti.multiplier while multiplier >= 1: picked_images.append(iti) multiplier -= 1 - # deal with fractional remainder - if multiplier > randomizer.uniform(0, 1): + # fractional remainders must be dealt with separately + if multiplier > 0: + fractional_images.append((iti, multiplier)) + + target_epoch_size = self.epoch_size + while len(picked_images) < target_epoch_size and len(fractional_images) > 0: + # cycle through fractional_images, randomly shifting each over to picked_images based on its multiplier + iti, multiplier = fractional_images.pop(0) + if randomizer.uniform(0, 1) < multiplier: + # shift it over to picked_images picked_images.append(iti) + else: + # put it back and move on to the next + fractional_images.append((iti, multiplier)) + + assert len(picked_images) == target_epoch_size, "Something went wrong while attempting to apply multipliers" return picked_images @@ -98,20 +113,19 @@ class DataLoaderMultiAspect(): buckets[(target_wh[0],target_wh[1])] = [] buckets[(target_wh[0],target_wh[1])].append(image_caption_pair) - if len(buckets) > 1: - for bucket in buckets: - truncate_count = len(buckets[bucket]) % batch_size - if truncate_count > 0: - runt_bucket = buckets[bucket][-truncate_count:] - for item in runt_bucket: - item.runt_size = truncate_count - while len(runt_bucket) < batch_size: - runt_bucket.append(random.choice(runt_bucket)) + for bucket in buckets: + truncate_count = len(buckets[bucket]) % batch_size + if truncate_count > 0: + runt_bucket = buckets[bucket][-truncate_count:] + for item in runt_bucket: + item.runt_size = truncate_count + while len(runt_bucket) < batch_size: + runt_bucket.append(random.choice(runt_bucket)) - current_bucket_size = len(buckets[bucket]) + current_bucket_size = len(buckets[bucket]) - buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count] - buckets[bucket].extend(runt_bucket) + buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count] + buckets[bucket].extend(runt_bucket) # flatten the buckets items: list[ImageTrainItem] = [] diff --git a/train.py b/train.py index 9ef5d32..32f35ba 100644 --- a/train.py +++ b/train.py @@ -160,20 +160,21 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs): """ updates the vram usage for the epoch """ - gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory() - log_writer.add_scalar("performance/vram", gpu_used_mem, global_step) - epoch_mem_color = Style.RESET_ALL - if gpu_used_mem > 0.93 * gpu_total_mem: - epoch_mem_color = Fore.LIGHTRED_EX - elif gpu_used_mem > 0.85 * gpu_total_mem: - epoch_mem_color = Fore.LIGHTYELLOW_EX - elif gpu_used_mem > 0.7 * gpu_total_mem: - epoch_mem_color = Fore.LIGHTGREEN_EX - elif gpu_used_mem < 0.5 * gpu_total_mem: - epoch_mem_color = Fore.LIGHTBLUE_EX + if gpu is not None: + gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory() + log_writer.add_scalar("performance/vram", gpu_used_mem, global_step) + epoch_mem_color = Style.RESET_ALL + if gpu_used_mem > 0.93 * gpu_total_mem: + epoch_mem_color = Fore.LIGHTRED_EX + elif gpu_used_mem > 0.85 * gpu_total_mem: + epoch_mem_color = Fore.LIGHTYELLOW_EX + elif gpu_used_mem > 0.7 * gpu_total_mem: + epoch_mem_color = Fore.LIGHTGREEN_EX + elif gpu_used_mem < 0.5 * gpu_total_mem: + epoch_mem_color = Fore.LIGHTBLUE_EX - if logs is not None: - epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}") + if logs is not None: + epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}") def set_args_12gb(args): @@ -372,6 +373,7 @@ def main(args): else: logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.") device = 'cpu' + gpu = None log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}") @@ -714,8 +716,9 @@ def main(args): if not os.path.exists(f"{log_folder}/samples/"): os.makedirs(f"{log_folder}/samples/") - gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory() - logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB") + if gpu is not None: + gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory() + logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB") logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes") logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs") From 19347bcaa81598507a02b687fb33d640aa3744f8 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 8 Feb 2023 14:15:54 +0100 Subject: [PATCH 3/3] make fractional multiplier logic apply per-directory --- data/data_loader.py | 39 +++++++++++++++++---------------------- train.py | 1 + 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index 28fcd00..d97e38e 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -15,13 +15,13 @@ limitations under the License. """ import bisect import logging -from functools import reduce +import os.path +from collections import defaultdict import math -import copy import random -from data.image_train_item import ImageTrainItem, ImageCaption -import PIL +from data.image_train_item import ImageTrainItem +import PIL.Image PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default @@ -39,9 +39,9 @@ class DataLoaderMultiAspect(): self.prepared_train_data = image_train_items random.Random(self.seed).shuffle(self.prepared_train_data) self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating()) - self.epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data])) - if self.epoch_size != len(self.prepared_train_data): - logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.epoch_size} images.") + expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data])) + if expected_epoch_size != len(self.prepared_train_data): + logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {expected_epoch_size} images.") else: logging.info(f" * DLMA initialized with {len(image_train_items)} images.") @@ -50,12 +50,12 @@ class DataLoaderMultiAspect(): self.__update_rating_sums() - def __pick_multiplied_set(self, randomizer): + def __pick_multiplied_set(self, randomizer: random.Random): """ Deals with multiply.txt whole and fractional numbers """ picked_images = [] - fractional_images = [] + fractional_images_per_directory = defaultdict(list[ImageTrainItem]) for iti in self.prepared_train_data: multiplier = iti.multiplier while multiplier >= 1: @@ -63,20 +63,15 @@ class DataLoaderMultiAspect(): multiplier -= 1 # fractional remainders must be dealt with separately if multiplier > 0: - fractional_images.append((iti, multiplier)) + directory = os.path.dirname(iti.pathname) + fractional_images_per_directory[directory].append(iti) - target_epoch_size = self.epoch_size - while len(picked_images) < target_epoch_size and len(fractional_images) > 0: - # cycle through fractional_images, randomly shifting each over to picked_images based on its multiplier - iti, multiplier = fractional_images.pop(0) - if randomizer.uniform(0, 1) < multiplier: - # shift it over to picked_images - picked_images.append(iti) - else: - # put it back and move on to the next - fractional_images.append((iti, multiplier)) - - assert len(picked_images) == target_epoch_size, "Something went wrong while attempting to apply multipliers" + # resolve fractional parts per-directory + for _, fractional_items in fractional_images_per_directory.items(): + randomizer.shuffle(fractional_items) + multiplier = fractional_items[0].multiplier % 1.0 + count_to_take = math.ceil(multiplier * len(fractional_items)) + picked_images.extend(fractional_items[:count_to_take]) return picked_images diff --git a/train.py b/train.py index 32f35ba..c63c471 100644 --- a/train.py +++ b/train.py @@ -550,6 +550,7 @@ def main(args): except Exception as e: traceback.print_exc() logging.error(" * Failed to load checkpoint *") + raise if args.gradient_checkpointing: unet.enable_gradient_checkpointing()