Merge pull request #193 from damian0815/feat_user_defined_batching

User defined batching
This commit is contained in:
Victor Hall 2023-06-10 13:08:19 -04:00 committed by GitHub
commit 6f64efaaaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 256 additions and 25 deletions

View File

@ -248,7 +248,7 @@ def __get_all_aspects():
]
def get_rational_aspect_ratio(bucket_wh: Tuple[int]) -> Tuple[int]:
def get_rational_aspect_ratio(bucket_wh: Tuple[int, int]) -> Tuple[int]:
def farey_aspect_ratio_pair(x: float, max_denominator_value: int):
if x <= 1:
return farey_aspect_ratio_pair_lt1(x, max_denominator_value)

View File

@ -15,15 +15,17 @@ limitations under the License.
"""
import bisect
import logging
import os.path
from collections import defaultdict
import math
import copy
import random
from data.image_train_item import ImageTrainItem
from typing import List, Dict
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
import PIL.Image
from utils.first_fit_decreasing import first_fit_decreasing
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
class DataLoaderMultiAspect():
@ -34,9 +36,10 @@ class DataLoaderMultiAspect():
seed: random seed
batch_size: number of images per batch
"""
def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1, grad_accum=1):
self.seed = seed
self.batch_size = batch_size
self.grad_accum = grad_accum
self.prepared_train_data = image_train_items
random.Random(self.seed).shuffle(self.prepared_train_data)
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
@ -103,14 +106,18 @@ class DataLoaderMultiAspect():
buckets = {}
batch_size = self.batch_size
grad_accum = self.grad_accum
for image_caption_pair in picked_images:
image_caption_pair.runt_size = 0
target_wh = image_caption_pair.target_wh
if (target_wh[0],target_wh[1]) not in buckets:
buckets[(target_wh[0],target_wh[1])] = []
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
bucket_key = (image_caption_pair.batch_id,
image_caption_pair.target_wh[0],
image_caption_pair.target_wh[1])
if bucket_key not in buckets:
buckets[bucket_key] = []
buckets[bucket_key].append(image_caption_pair)
# handle runts by randomly duplicating items
for bucket in buckets:
truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0:
@ -125,13 +132,19 @@ class DataLoaderMultiAspect():
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
buckets[bucket].extend(runt_bucket)
# flatten the buckets
items: list[ImageTrainItem] = []
for bucket in buckets:
items.extend(buckets[bucket])
# handle batch_id
# unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
batch_size=batch_size,
grad_accum=grad_accum)
effective_batch_size = batch_size * grad_accum
items = chunked_shuffle(items, chunk_size=effective_batch_size, randomizer=randomizer)
return items
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
"""
Picks a random subset of all images
@ -174,4 +187,53 @@ class DataLoaderMultiAspect():
self.ratings_summed: list[float] = []
for item in self.prepared_train_data:
self.rating_overall_sum += item.caption.rating()
self.ratings_summed.append(self.rating_overall_sum)
self.ratings_summed.append(self.rating_overall_sum)
def chunk(l: List, chunk_size) -> List:
num_chunks = int(math.ceil(float(len(l)) / chunk_size))
return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]
def unchunk(chunked_list: List):
return [i for c in chunked_list for i in c]
def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
batch_ids = [k[0] for k in buckets.keys()]
items_by_batch_id = {}
for batch_id in batch_ids:
items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
return items_by_batch_id
def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]],
batch_size: int,
grad_accum: int) -> List[ImageTrainItem]:
# precondition: items_by_batch_id has no incomplete batches
assert(all((len(v) % batch_size)==0 for v in items_by_batch_id.values()))
# ensure we don't mix up aspect ratios by treating each chunk of batch_size images as
# a single unit to pass to first_fit_decreasing()
filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,
batch_size=grad_accum,
filler_items=filler_items)
items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
return items
def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List:
"""
Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
"""
# chunk by effective batch size
chunks = chunk(l, chunk_size)
# preserve last chunk as last if it is incomplete
last_chunk = None
if len(chunks[-1]) < chunk_size:
last_chunk = chunks.pop(-1)
randomizer.shuffle(chunks)
if last_chunk is not None:
chunks.append(last_chunk)
l = unchunk(chunks)
return l

View File

@ -1,10 +1,7 @@
import os
import logging
import yaml
import json
from functools import total_ordering
from attrs import define, field, Factory
from attrs import define, field
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *
from typing import Iterable
@ -50,6 +47,7 @@ class ImageConfig:
rating: float = None
max_caption_length: int = None
tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
batch_id: str = None
# Options
multiply: float = None
@ -70,6 +68,7 @@ class ImageConfig:
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
flip_p=overlay(other.flip_p, self.flip_p),
shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
batch_id=overlay(other.batch_id, self.batch_id)
)
@classmethod
@ -84,6 +83,7 @@ class ImageConfig:
cond_dropout=data.get("cond_dropout"),
flip_p=data.get("flip_p"),
shuffle_tags=data.get("shuffle_tags"),
batch_id=data.get("batch_id")
)
# Alternatively parse from dedicated `caption` attribute
@ -168,6 +168,8 @@ class Dataset:
cfgs.append(ImageConfig.from_file(fileset['local.yaml']))
if 'local.yml' in fileset:
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
if 'batch_id.txt' in fileset:
cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt'])))
result = ImageConfig.fold(cfgs)
if 'shuffle_tags.txt' in fileset:
@ -262,6 +264,7 @@ class Dataset:
multiplier=config.multiply or 1.0,
cond_dropout=config.cond_dropout,
shuffle_tags=config.shuffle_tags,
batch_id=config.batch_id
)
items.append(item)
except Exception as e:

View File

@ -124,7 +124,7 @@ class ImageTrainItem:
flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the images. The rating is measured in comparison to the other images.
"""
def __init__(self,
def __init__(self,
image: PIL.Image,
caption: ImageCaption,
aspects: list[float],
@ -133,6 +133,7 @@ class ImageTrainItem:
multiplier: float=1.0,
cond_dropout=None,
shuffle_tags=False,
batch_id: str=None
):
self.caption = caption
self.aspects = aspects
@ -143,6 +144,8 @@ class ImageTrainItem:
self.multiplier = multiplier
self.cond_dropout = cond_dropout
self.shuffle_tags = shuffle_tags
self.batch_id = batch_id or DEFAULT_BATCH_ID
self.target_wh = None
self.image_size = None
if image is None or len(image) == 0:
@ -351,3 +354,6 @@ class ImageTrainItem:
image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
return image
DEFAULT_BATCH_ID = "default_batch"

View File

@ -151,6 +151,17 @@ Test results: https://huggingface.co/panopstor/ff7r-stable-diffusion/blob/main/z
Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergance over time.
## Keeping images together (custom batching)
If you have a subset of your dataset that expresses the same style or concept, training quality may be improved by putting all of these images through the trainer together in the same batch or batches, instead of the default behaviour (which is to shuffle them randomly throughout the entire dataset).
To control this, put a file called `batch_id.txt` into a folder to give a unique name to the training data in this folder. For example, if you have a bunch of images of horses and you are trying to train them as a single concept, you can assign a unique name such as "my_horses" to these images by putting the word `my_horses` inside `batch_id.txt` in your folder with horse images.
> Note that because this creates extra aspect ratio buckets, you need to be very careful about correlating the number of images to your training batch size. Aim to have an exact multiple of `batch_size` images at each aspect ratio. For example, if your `batch_size` is 6 and you have images with aspect ratios 4:3, 3:4, and 9:16, you should add or delete images until you have an exact multiple of 6 images (i.e. 6, 12, 28, ...) for each aspect ratio. If you do not do this, the bucketer will duplicate images to fill up each aspect ratio bucket. You'll probably also want to use manual validation to prevent the validator from messing this up, too.
If you are using `.yaml` files for captioning, you can alternatively add a `batch_id: ` entry to either `local.yaml` or the individual images' `.yaml` files. Note that neither `.yaml` nor `batch_id.txt` files act recursively: they do not apply to subfolders.
# Stuff you probably don't need to mess with, but well here it is:

View File

@ -0,0 +1,81 @@
import unittest
from utils.first_fit_decreasing import first_fit_decreasing
class TestFirstFitDecreasing(unittest.TestCase):
def test_single_basic(self):
input = [[1, 2, 3, 4, 5, 6]]
output = first_fit_decreasing(input, batch_size=2)
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
input = [[1, 2, 3, 4, 5, 6]]
output = first_fit_decreasing(input, batch_size=3)
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
input = [[1, 2, 3, 4, 5, 6]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
input = [[1, 2, 3]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 2, 3])
def test_multi_basic(self):
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=2)
self.assertEqual(output, [1, 1, 1, 1, 2, 2])
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=3)
self.assertEqual(output, [1, 1, 1, 2, 2, 1])
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 1, 1, 1, 2, 2])
input = [[1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [2, 2, 1, 1])
def test_multi_complex(self):
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
output = first_fit_decreasing(input, batch_size=2)
self.assertEqual(output, [1, 1, 3, 3, 1, 1, 2, 2, 3])
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
output = first_fit_decreasing(input, batch_size=3)
self.assertEqual(output, [1, 1, 1, 3, 3, 3, 2, 2, 1])
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 1, 1, 1, 3, 3, 3, 2, 2])
input = [[1, 1], [2, 2], [3, 3, 3]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [3, 3, 3, 2, 1, 1, 2])
input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 3, 2, 2, 2, 3, 5, 5, 3])
def test_filler_bucket(self):
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=2, filler_items=[9, 9])
self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9])
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=3, filler_items=[9, 9])
self.assertEqual(output, [1, 1, 1, 2, 2, 9, 1, 9])
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9])
input = [[1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
self.assertEqual(output, [2, 2, 9, 9, 1, 1])
input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]]
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 9, 3, 3, 3, 9, 2, 2, 2, 5, 5])

View File

@ -55,7 +55,7 @@ from data.data_loader import DataLoaderMultiAspect
from data.every_dream import EveryDreamBatch, build_torch_dataloader
from data.every_dream_validation import EveryDreamValidator
from data.image_train_item import ImageTrainItem
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter
from utils.isolate_rng import isolate_rng
@ -297,19 +297,23 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem
# at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0
warn_bucket_dupe_ratio = 0.5
ar_buckets = set([tuple(i.target_wh) for i in items])
def make_bucket_key(item):
return (item.batch_id, int(item.target_wh[0]), int(item.target_wh[1]))
ar_buckets = set(make_bucket_key(i) for i in items)
for ar_bucket in ar_buckets:
count = len([i for i in items if tuple(i.target_wh) == ar_bucket])
count = len([i for i in items if make_bucket_key(i) == ar_bucket])
runt_size = batch_size - (count % batch_size)
bucket_dupe_ratio = runt_size / count
if bucket_dupe_ratio > warn_bucket_dupe_ratio:
aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket)
aspect_ratio_rational = aspects.get_rational_aspect_ratio((ar_bucket[1], ar_bucket[2]))
aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}"
batch_id_description = "" if ar_bucket[0] == DEFAULT_BATCH_ID else f" for batch id '{ar_bucket[0]}'"
effective_multiplier = round(1 + bucket_dupe_ratio, 1)
logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} "
f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier "
f"of {effective_multiplier}, which may cause problems. Consider adding {runt_size} or "
f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.")
f"more images with aspect ratio {aspect_ratio_description}{batch_id_description}, or reducing your batch_size.")
def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]:
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
@ -548,6 +552,7 @@ def main(args):
image_train_items=image_train_items,
seed=seed,
batch_size=args.batch_size,
grad_accum=args.grad_accum
)
train_batch = EveryDreamBatch(

View File

@ -0,0 +1,63 @@
import copy
from typing import List
def first_fit_decreasing(input_items: List[List], batch_size: int, filler_items: List=[]) -> List:
"""
Given as input a list of lists, batch the items so that as much as possible the members of each of the original
lists end up in the same batch. Pad out too-short batches by taking items from the filler_items list, if available.
@return flattened list of all items in input_items and filler_items, arranged such that, as much as possible, items
that are in the same input list end up in the same batch.
"""
def sort_by_length(items: List[List]) -> List[List]:
return sorted(items, key=lambda x: len(x))
remaining = input_items
output = []
while remaining:
remaining = sort_by_length(remaining)
longest = remaining.pop()
if len(longest) == 0:
continue
if len(longest) >= batch_size:
output.append(longest[0:batch_size])
del longest[0:batch_size]
if len(longest)>0:
remaining.append(longest)
else:
# need to build this chunk by combining multiple
combined = longest
while True:
fill_length = batch_size - len(combined)
if fill_length == 0:
break
if len(remaining) == 0 and len(filler_items) == 0:
break
from_filler_bucket = filler_items[0:fill_length]
if len(from_filler_bucket) > 0:
del filler_items[0:fill_length]
combined.extend(from_filler_bucket)
continue
filler = next((r for r in remaining if len(r) <= fill_length), None)
if filler is not None:
remaining.remove(filler)
combined.extend(filler)
else:
# steal from the next longest
next_longest = remaining.pop()
combined.extend(next_longest[0:fill_length])
del next_longest[0:fill_length]
if len(next_longest) > 0:
remaining.append(next_longest)
output.append(combined)
output.append(filler_items)
return [i for o in output for i in o]