diff --git a/data/aspects.py b/data/aspects.py index 42265a1..e5d2bd7 100644 --- a/data/aspects.py +++ b/data/aspects.py @@ -248,7 +248,7 @@ def __get_all_aspects(): ] -def get_rational_aspect_ratio(bucket_wh: Tuple[int]) -> Tuple[int]: +def get_rational_aspect_ratio(bucket_wh: Tuple[int, int]) -> Tuple[int]: def farey_aspect_ratio_pair(x: float, max_denominator_value: int): if x <= 1: return farey_aspect_ratio_pair_lt1(x, max_denominator_value) diff --git a/data/data_loader.py b/data/data_loader.py index 9d9f7c4..1ab9b67 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -15,15 +15,17 @@ limitations under the License. """ import bisect import logging -import os.path -from collections import defaultdict import math import copy import random -from data.image_train_item import ImageTrainItem +from typing import List, Dict + +from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID import PIL.Image +from utils.first_fit_decreasing import first_fit_decreasing + PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default class DataLoaderMultiAspect(): @@ -34,9 +36,10 @@ class DataLoaderMultiAspect(): seed: random seed batch_size: number of images per batch """ - def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1): + def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1, grad_accum=1): self.seed = seed self.batch_size = batch_size + self.grad_accum = grad_accum self.prepared_train_data = image_train_items random.Random(self.seed).shuffle(self.prepared_train_data) self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating()) @@ -103,14 +106,18 @@ class DataLoaderMultiAspect(): buckets = {} batch_size = self.batch_size + grad_accum = self.grad_accum + for image_caption_pair in picked_images: image_caption_pair.runt_size = 0 - target_wh = image_caption_pair.target_wh - - if (target_wh[0],target_wh[1]) not in buckets: - buckets[(target_wh[0],target_wh[1])] = [] - buckets[(target_wh[0],target_wh[1])].append(image_caption_pair) + bucket_key = (image_caption_pair.batch_id, + image_caption_pair.target_wh[0], + image_caption_pair.target_wh[1]) + if bucket_key not in buckets: + buckets[bucket_key] = [] + buckets[bucket_key].append(image_caption_pair) + # handle runts by randomly duplicating items for bucket in buckets: truncate_count = len(buckets[bucket]) % batch_size if truncate_count > 0: @@ -125,13 +132,19 @@ class DataLoaderMultiAspect(): buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count] buckets[bucket].extend(runt_bucket) - # flatten the buckets - items: list[ImageTrainItem] = [] - for bucket in buckets: - items.extend(buckets[bucket]) + # handle batch_id + # unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID. + items_by_batch_id = collapse_buckets_by_batch_id(buckets) + items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id, + batch_size=batch_size, + grad_accum=grad_accum) + + effective_batch_size = batch_size * grad_accum + items = chunked_shuffle(items, chunk_size=effective_batch_size, randomizer=randomizer) return items + def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]: """ Picks a random subset of all images @@ -174,4 +187,53 @@ class DataLoaderMultiAspect(): self.ratings_summed: list[float] = [] for item in self.prepared_train_data: self.rating_overall_sum += item.caption.rating() - self.ratings_summed.append(self.rating_overall_sum) \ No newline at end of file + self.ratings_summed.append(self.rating_overall_sum) + + +def chunk(l: List, chunk_size) -> List: + num_chunks = int(math.ceil(float(len(l)) / chunk_size)) + return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)] + +def unchunk(chunked_list: List): + return [i for c in chunked_list for i in c] + +def collapse_buckets_by_batch_id(buckets: Dict) -> Dict: + batch_ids = [k[0] for k in buckets.keys()] + items_by_batch_id = {} + for batch_id in batch_ids: + items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id]) + return items_by_batch_id + +def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]], + batch_size: int, + grad_accum: int) -> List[ImageTrainItem]: + # precondition: items_by_batch_id has no incomplete batches + assert(all((len(v) % batch_size)==0 for v in items_by_batch_id.values())) + # ensure we don't mix up aspect ratios by treating each chunk of batch_size images as + # a single unit to pass to first_fit_decreasing() + filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size) + custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID] + neighbourly_chunked_items = first_fit_decreasing(custom_batched_items, + batch_size=grad_accum, + filler_items=filler_items) + + items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items) + return items + +def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List: + """ + Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk. + If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk) + """ + + # chunk by effective batch size + chunks = chunk(l, chunk_size) + # preserve last chunk as last if it is incomplete + last_chunk = None + if len(chunks[-1]) < chunk_size: + last_chunk = chunks.pop(-1) + randomizer.shuffle(chunks) + if last_chunk is not None: + chunks.append(last_chunk) + l = unchunk(chunks) + return l diff --git a/data/dataset.py b/data/dataset.py index c8fc2f3..a478a46 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -1,10 +1,7 @@ -import os -import logging import yaml import json -from functools import total_ordering -from attrs import define, field, Factory +from attrs import define, field from data.image_train_item import ImageCaption, ImageTrainItem from utils.fs_helpers import * from typing import Iterable @@ -50,6 +47,7 @@ class ImageConfig: rating: float = None max_caption_length: int = None tags: dict[Tag, None] = field(factory=dict, converter=safe_set) + batch_id: str = None # Options multiply: float = None @@ -70,6 +68,7 @@ class ImageConfig: cond_dropout=overlay(other.cond_dropout, self.cond_dropout), flip_p=overlay(other.flip_p, self.flip_p), shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags), + batch_id=overlay(other.batch_id, self.batch_id) ) @classmethod @@ -84,6 +83,7 @@ class ImageConfig: cond_dropout=data.get("cond_dropout"), flip_p=data.get("flip_p"), shuffle_tags=data.get("shuffle_tags"), + batch_id=data.get("batch_id") ) # Alternatively parse from dedicated `caption` attribute @@ -168,6 +168,8 @@ class Dataset: cfgs.append(ImageConfig.from_file(fileset['local.yaml'])) if 'local.yml' in fileset: cfgs.append(ImageConfig.from_file(fileset['local.yml'])) + if 'batch_id.txt' in fileset: + cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt']))) result = ImageConfig.fold(cfgs) if 'shuffle_tags.txt' in fileset: @@ -262,6 +264,7 @@ class Dataset: multiplier=config.multiply or 1.0, cond_dropout=config.cond_dropout, shuffle_tags=config.shuffle_tags, + batch_id=config.batch_id ) items.append(item) except Exception as e: diff --git a/data/image_train_item.py b/data/image_train_item.py index 3623b2b..368eac7 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -124,7 +124,7 @@ class ImageTrainItem: flip_p: probability of flipping image (0.0 to 1.0) rating: the relative rating of the images. The rating is measured in comparison to the other images. """ - def __init__(self, + def __init__(self, image: PIL.Image, caption: ImageCaption, aspects: list[float], @@ -133,6 +133,7 @@ class ImageTrainItem: multiplier: float=1.0, cond_dropout=None, shuffle_tags=False, + batch_id: str=None ): self.caption = caption self.aspects = aspects @@ -143,6 +144,8 @@ class ImageTrainItem: self.multiplier = multiplier self.cond_dropout = cond_dropout self.shuffle_tags = shuffle_tags + self.batch_id = batch_id or DEFAULT_BATCH_ID + self.target_wh = None self.image_size = None if image is None or len(image) == 0: @@ -351,3 +354,6 @@ class ImageTrainItem: image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy)) return image + + +DEFAULT_BATCH_ID = "default_batch" diff --git a/doc/ADVANCED_TWEAKING.md b/doc/ADVANCED_TWEAKING.md index c3ac9f1..b8afccc 100644 --- a/doc/ADVANCED_TWEAKING.md +++ b/doc/ADVANCED_TWEAKING.md @@ -151,6 +151,17 @@ Test results: https://huggingface.co/panopstor/ff7r-stable-diffusion/blob/main/z Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergance over time. +## Keeping images together (custom batching) + +If you have a subset of your dataset that expresses the same style or concept, training quality may be improved by putting all of these images through the trainer together in the same batch or batches, instead of the default behaviour (which is to shuffle them randomly throughout the entire dataset). + +To control this, put a file called `batch_id.txt` into a folder to give a unique name to the training data in this folder. For example, if you have a bunch of images of horses and you are trying to train them as a single concept, you can assign a unique name such as "my_horses" to these images by putting the word `my_horses` inside `batch_id.txt` in your folder with horse images. + +> Note that because this creates extra aspect ratio buckets, you need to be very careful about correlating the number of images to your training batch size. Aim to have an exact multiple of `batch_size` images at each aspect ratio. For example, if your `batch_size` is 6 and you have images with aspect ratios 4:3, 3:4, and 9:16, you should add or delete images until you have an exact multiple of 6 images (i.e. 6, 12, 28, ...) for each aspect ratio. If you do not do this, the bucketer will duplicate images to fill up each aspect ratio bucket. You'll probably also want to use manual validation to prevent the validator from messing this up, too. + +If you are using `.yaml` files for captioning, you can alternatively add a `batch_id: ` entry to either `local.yaml` or the individual images' `.yaml` files. Note that neither `.yaml` nor `batch_id.txt` files act recursively: they do not apply to subfolders. + + # Stuff you probably don't need to mess with, but well here it is: diff --git a/test/test_first_fit_decreasing.py b/test/test_first_fit_decreasing.py new file mode 100644 index 0000000..f64b916 --- /dev/null +++ b/test/test_first_fit_decreasing.py @@ -0,0 +1,81 @@ +import unittest + +from utils.first_fit_decreasing import first_fit_decreasing + +class TestFirstFitDecreasing(unittest.TestCase): + + def test_single_basic(self): + input = [[1, 2, 3, 4, 5, 6]] + output = first_fit_decreasing(input, batch_size=2) + self.assertEqual(output, [1, 2, 3, 4, 5, 6]) + + input = [[1, 2, 3, 4, 5, 6]] + output = first_fit_decreasing(input, batch_size=3) + self.assertEqual(output, [1, 2, 3, 4, 5, 6]) + + input = [[1, 2, 3, 4, 5, 6]] + output = first_fit_decreasing(input, batch_size=4) + self.assertEqual(output, [1, 2, 3, 4, 5, 6]) + + input = [[1, 2, 3]] + output = first_fit_decreasing(input, batch_size=4) + self.assertEqual(output, [1, 2, 3]) + + def test_multi_basic(self): + input = [[1, 1, 1, 1], [2, 2]] + output = first_fit_decreasing(input, batch_size=2) + self.assertEqual(output, [1, 1, 1, 1, 2, 2]) + + input = [[1, 1, 1, 1], [2, 2]] + output = first_fit_decreasing(input, batch_size=3) + self.assertEqual(output, [1, 1, 1, 2, 2, 1]) + + input = [[1, 1, 1, 1], [2, 2]] + output = first_fit_decreasing(input, batch_size=4) + self.assertEqual(output, [1, 1, 1, 1, 2, 2]) + + input = [[1, 1], [2, 2]] + output = first_fit_decreasing(input, batch_size=4) + self.assertEqual(output, [2, 2, 1, 1]) + + def test_multi_complex(self): + input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]] + output = first_fit_decreasing(input, batch_size=2) + self.assertEqual(output, [1, 1, 3, 3, 1, 1, 2, 2, 3]) + + input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]] + output = first_fit_decreasing(input, batch_size=3) + self.assertEqual(output, [1, 1, 1, 3, 3, 3, 2, 2, 1]) + + input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]] + output = first_fit_decreasing(input, batch_size=4) + self.assertEqual(output, [1, 1, 1, 1, 3, 3, 3, 2, 2]) + + input = [[1, 1], [2, 2], [3, 3, 3]] + output = first_fit_decreasing(input, batch_size=4) + self.assertEqual(output, [3, 3, 3, 2, 1, 1, 2]) + + input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]] + output = first_fit_decreasing(input, batch_size=4) + self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 3, 2, 2, 2, 3, 5, 5, 3]) + + def test_filler_bucket(self): + input = [[1, 1, 1, 1], [2, 2]] + output = first_fit_decreasing(input, batch_size=2, filler_items=[9, 9]) + self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9]) + + input = [[1, 1, 1, 1], [2, 2]] + output = first_fit_decreasing(input, batch_size=3, filler_items=[9, 9]) + self.assertEqual(output, [1, 1, 1, 2, 2, 9, 1, 9]) + + input = [[1, 1, 1, 1], [2, 2]] + output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9]) + self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9]) + + input = [[1, 1], [2, 2]] + output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9]) + self.assertEqual(output, [2, 2, 9, 9, 1, 1]) + + input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]] + output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9]) + self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 9, 3, 3, 3, 9, 2, 2, 2, 5, 5]) diff --git a/train.py b/train.py index 257bf12..41edc53 100644 --- a/train.py +++ b/train.py @@ -55,7 +55,7 @@ from data.data_loader import DataLoaderMultiAspect from data.every_dream import EveryDreamBatch, build_torch_dataloader from data.every_dream_validation import EveryDreamValidator -from data.image_train_item import ImageTrainItem +from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID from utils.huggingface_downloader import try_download_model_from_hf from utils.convert_diff_to_ckpt import convert as converter from utils.isolate_rng import isolate_rng @@ -297,19 +297,23 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem # at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0 warn_bucket_dupe_ratio = 0.5 - ar_buckets = set([tuple(i.target_wh) for i in items]) + def make_bucket_key(item): + return (item.batch_id, int(item.target_wh[0]), int(item.target_wh[1])) + + ar_buckets = set(make_bucket_key(i) for i in items) for ar_bucket in ar_buckets: - count = len([i for i in items if tuple(i.target_wh) == ar_bucket]) + count = len([i for i in items if make_bucket_key(i) == ar_bucket]) runt_size = batch_size - (count % batch_size) bucket_dupe_ratio = runt_size / count if bucket_dupe_ratio > warn_bucket_dupe_ratio: - aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket) + aspect_ratio_rational = aspects.get_rational_aspect_ratio((ar_bucket[1], ar_bucket[2])) aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}" + batch_id_description = "" if ar_bucket[0] == DEFAULT_BATCH_ID else f" for batch id '{ar_bucket[0]}'" effective_multiplier = round(1 + bucket_dupe_ratio, 1) logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} " f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier " f"of {effective_multiplier}, which may cause problems. Consider adding {runt_size} or " - f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.") + f"more images with aspect ratio {aspect_ratio_description}{batch_id_description}, or reducing your batch_size.") def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]: logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}") @@ -548,6 +552,7 @@ def main(args): image_train_items=image_train_items, seed=seed, batch_size=args.batch_size, + grad_accum=args.grad_accum ) train_batch = EveryDreamBatch( diff --git a/utils/first_fit_decreasing.py b/utils/first_fit_decreasing.py new file mode 100644 index 0000000..e323637 --- /dev/null +++ b/utils/first_fit_decreasing.py @@ -0,0 +1,63 @@ +import copy +from typing import List + +def first_fit_decreasing(input_items: List[List], batch_size: int, filler_items: List=[]) -> List: + """ + Given as input a list of lists, batch the items so that as much as possible the members of each of the original + lists end up in the same batch. Pad out too-short batches by taking items from the filler_items list, if available. + + @return flattened list of all items in input_items and filler_items, arranged such that, as much as possible, items + that are in the same input list end up in the same batch. + """ + + def sort_by_length(items: List[List]) -> List[List]: + return sorted(items, key=lambda x: len(x)) + + remaining = input_items + output = [] + while remaining: + remaining = sort_by_length(remaining) + longest = remaining.pop() + if len(longest) == 0: + continue + if len(longest) >= batch_size: + output.append(longest[0:batch_size]) + del longest[0:batch_size] + if len(longest)>0: + remaining.append(longest) + else: + # need to build this chunk by combining multiple + combined = longest + while True: + fill_length = batch_size - len(combined) + if fill_length == 0: + break + + if len(remaining) == 0 and len(filler_items) == 0: + break + + from_filler_bucket = filler_items[0:fill_length] + if len(from_filler_bucket) > 0: + del filler_items[0:fill_length] + combined.extend(from_filler_bucket) + continue + + filler = next((r for r in remaining if len(r) <= fill_length), None) + if filler is not None: + remaining.remove(filler) + combined.extend(filler) + else: + # steal from the next longest + next_longest = remaining.pop() + combined.extend(next_longest[0:fill_length]) + del next_longest[0:fill_length] + if len(next_longest) > 0: + remaining.append(next_longest) + output.append(combined) + + output.append(filler_items) + return [i for o in output for i in o] + + + +