shuffle named batches while respecting and accounting for grad_accum

damian 2023-06-07 18:07:37 +02:00
parent a8455b9427
commit 59fc9891d4
4 changed files with 152 additions and 13 deletions
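Context for the change: with gradient accumulation, one optimizer step consumes grad_accum consecutive batches, i.e. batch_size * grad_accum images. Any shuffle therefore has to move whole runs of that effective size; shuffling individual images would scatter named batches and aspect-ratio buckets across accumulation steps. A minimal standalone sketch of that constraint (toy data and names, not from the commit):

    import random

    batch_size, grad_accum = 2, 2
    effective_batch_size = batch_size * grad_accum  # images consumed per optimizer step

    # hypothetical image ids; 'a*' and 'b*' stand for two named batches
    items = ["a1", "a2", "a3", "a4", "b1", "b2", "b3", "b4"]

    # shuffle whole effective-batch-sized runs, never single images
    runs = [items[i:i + effective_batch_size]
            for i in range(0, len(items), effective_batch_size)]
    random.shuffle(runs)
    shuffled = [item for run in runs for item in run]
    # every accumulation run still holds images from a single named batch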

View File

@@ -27,6 +27,8 @@ from typing import Tuple, List
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
import PIL.Image
from utils.first_fit_decreasing import first_fit_decreasing

PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default

class DataLoaderMultiAspect():
@@ -37,9 +39,10 @@ class DataLoaderMultiAspect():
        seed: random seed
        batch_size: number of images per batch
        grad_accum: number of batches to accumulate before each optimizer step
    """
    def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
    def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1, grad_accum=1):
        self.seed = seed
        self.batch_size = batch_size
        self.grad_accum = grad_accum
        self.prepared_train_data = image_train_items
        random.Random(self.seed).shuffle(self.prepared_train_data)
        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
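An aside on the two lines above: the shuffle uses a privately seeded random.Random, so it is reproducible and leaves the global RNG alone, and because Python's sorted is stable, the rating sort that follows groups items by rating without undoing the shuffled order within each rating group. A toy illustration (names hypothetical, not from the commit):

    import random

    items = list(range(8))
    random.Random(555).shuffle(items)      # same order on every run for this seed

    rating = {i: i % 2 for i in items}     # pretend two rating levels
    by_rating = sorted(items, key=lambda i: rating[i])
    # stable sort: inside each rating level, the shuffled order survives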
@@ -106,6 +109,7 @@ class DataLoaderMultiAspect():
        buckets = {}
        batch_size = self.batch_size
        grad_accum = self.grad_accum

        for image_caption_pair in picked_images:
            image_caption_pair.runt_size = 0
@@ -115,7 +119,7 @@ class DataLoaderMultiAspect():
                buckets[bucket_key] = []
            buckets[bucket_key].append(image_caption_pair)

        # handle runts in "general" buckets by randomly duplicating items
        # handle runts by randomly duplicating items
        for bucket in buckets:
            truncate_count = len(buckets[bucket]) % batch_size
            if truncate_count > 0:
@@ -130,10 +134,38 @@ class DataLoaderMultiAspect():
                buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
                buckets[bucket].extend(runt_bucket)

        # flatten the buckets
        items: list[ImageTrainItem] = []
        for bucket in buckets:
            items.extend(buckets[bucket])

        def chunk(l: List, chunk_size: int) -> List:
            num_chunks = math.ceil(len(l) / chunk_size)
            return [l[i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)]

        def unchunk(chunked_list: List):
            return [i for c in chunked_list for i in c]

        # interleave buckets while trying to maximise shared grad accum chunks
        batch_ids = [k[0] for k in buckets.keys()]
        items_by_batch_id = {}
        for batch_id in batch_ids:
            items_by_batch_id[batch_id] = unchunk([b for bucket_key, b in buckets.items() if bucket_key[0] == batch_id])

        # ensure we don't mix and match aspect ratios by treating each chunk of
        # batch_size images as a single unit to pass to first_fit_decreasing
        filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
        custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
        neighbourly_chunked_items = first_fit_decreasing(custom_batched_items, batch_size=grad_accum, filler_items=filler_items)

        items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)

        # chunk by effective batch size
        effective_batch_size = batch_size * grad_accum
        chunks = chunk(items, effective_batch_size)

        # shuffle, but preserve the last chunk as last if it is incomplete
        last_chunk = None
        if len(chunks) > 0 and len(chunks[-1]) < effective_batch_size:
            last_chunk = chunks.pop(-1)
        random.shuffle(chunks)
        if last_chunk is not None:
            chunks.append(last_chunk)

        # un-chunk
        items = unchunk(chunks)

        return items
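The final shuffle above has one subtlety: an incomplete trailing chunk must stay last, or a short run would land mid-epoch and break the effective batch size. A self-contained sanity check of just that step (toy integer items in place of ImageTrainItems; chunk/unchunk re-declared locally):

    import random

    def chunk(l, size):
        return [l[i:i + size] for i in range(0, len(l), size)]

    def unchunk(chunks):
        return [i for c in chunks for i in c]

    effective_batch_size = 4          # e.g. batch_size=2 * grad_accum=2
    items = list(range(10))           # two full runs plus a 2-item runt

    chunks = chunk(items, effective_batch_size)
    last_chunk = None
    if chunks and len(chunks[-1]) < effective_batch_size:
        last_chunk = chunks.pop(-1)   # hold the incomplete run aside
    random.shuffle(chunks)
    if last_chunk is not None:
        chunks.append(last_chunk)     # ...so it always ends up last

    assert unchunk(chunks)[-2:] == [8, 9]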

View File

@@ -4,7 +4,7 @@ from utils.first_fit_decreasing import first_fit_decreasing

class TestFirstFitDecreasing(unittest.TestCase):

    def test_basic(self):
    def test_single_basic(self):
        input = [[1, 2, 3, 4, 5, 6]]
        output = first_fit_decreasing(input, batch_size=2)
        self.assertEqual(output, [1, 2, 3, 4, 5, 6])
@@ -20,3 +20,62 @@ class TestFirstFitDecreasing(unittest.TestCase):
        input = [[1, 2, 3]]
        output = first_fit_decreasing(input, batch_size=4)
        self.assertEqual(output, [1, 2, 3])
    def test_multi_basic(self):
        input = [[1, 1, 1, 1], [2, 2]]
        output = first_fit_decreasing(input, batch_size=2)
        self.assertEqual(output, [1, 1, 1, 1, 2, 2])

        input = [[1, 1, 1, 1], [2, 2]]
        output = first_fit_decreasing(input, batch_size=3)
        self.assertEqual(output, [1, 1, 1, 2, 2, 1])

        input = [[1, 1, 1, 1], [2, 2]]
        output = first_fit_decreasing(input, batch_size=4)
        self.assertEqual(output, [1, 1, 1, 1, 2, 2])

        input = [[1, 1], [2, 2]]
        output = first_fit_decreasing(input, batch_size=4)
        self.assertEqual(output, [2, 2, 1, 1])

    def test_multi_complex(self):
        input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
        output = first_fit_decreasing(input, batch_size=2)
        self.assertEqual(output, [1, 1, 3, 3, 1, 1, 2, 2, 3])

        input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
        output = first_fit_decreasing(input, batch_size=3)
        self.assertEqual(output, [1, 1, 1, 3, 3, 3, 2, 2, 1])

        input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
        output = first_fit_decreasing(input, batch_size=4)
        self.assertEqual(output, [1, 1, 1, 1, 3, 3, 3, 2, 2])

        input = [[1, 1], [2, 2], [3, 3, 3]]
        output = first_fit_decreasing(input, batch_size=4)
        self.assertEqual(output, [3, 3, 3, 2, 1, 1, 2])

        input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]]
        output = first_fit_decreasing(input, batch_size=4)
        self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 3, 2, 2, 2, 3, 5, 5, 3])

    def test_filler_bucket(self):
        input = [[1, 1, 1, 1], [2, 2]]
        output = first_fit_decreasing(input, batch_size=2, filler_items=[9, 9])
        self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9])

        input = [[1, 1, 1, 1], [2, 2]]
        output = first_fit_decreasing(input, batch_size=3, filler_items=[9, 9])
        self.assertEqual(output, [1, 1, 1, 2, 2, 9, 1, 9])

        input = [[1, 1, 1, 1], [2, 2]]
        output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
        self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9])

        input = [[1, 1], [2, 2]]
        output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
        self.assertEqual(output, [2, 2, 9, 9, 1, 1])

        input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]]
        output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
        self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 9, 3, 3, 3, 9, 2, 2, 2, 5, 5])
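The expected outputs above follow mechanically from the implementation (final file in this commit): repeatedly take a full batch from the longest remaining list; when only short lists remain, combine them, preferring filler_items for padding; anything unused trails at the end. For example, the first test_multi_complex case can be reproduced directly:

    from utils.first_fit_decreasing import first_fit_decreasing

    # batches are drawn longest-list-first: [1,1], then [3,3], then [1,1],
    # then [2,2]; the leftover runt 3 trails (no filler items to pad it)
    print(first_fit_decreasing([[1, 1, 1, 1], [2, 2], [3, 3, 3]], batch_size=2))
    # [1, 1, 3, 3, 1, 1, 2, 2, 3]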

View File

@@ -552,6 +552,7 @@ def main(args):
        image_train_items=image_train_items,
        seed=seed,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum
    )

    train_batch = EveryDreamBatch(

View File

@@ -1,16 +1,63 @@
import copy
from typing import List


def first_fit_decreasing(input_list: List[List], batch_size: int) -> List:
def first_fit_decreasing(input_items: List[List], batch_size: int, filler_items: List = []) -> List:
    """
    Given as input a list of lists, batch the items so that as much as possible the members of each of the original
    lists end up in the same batch.
    lists end up in the same batch. Pad out too-short batches by taking items from the filler_items list, if available.

    @return a list of batches
    @return flattened list of all items in input_items and filler_items, arranged such that, as much as possible, items
            that are in the same input list end up in the same batch.
    """
    def sort_by_length(items: List[List]):
        return items.sort(key=lambda x: len(x), reverse=True)
    def sort_by_length(items: List[List]) -> List[List]:
        return sorted(items, key=lambda x: len(x))

    # copy so that neither the caller's list nor the shared default [] is mutated below
    filler_items = list(filler_items)
    remaining = list(input_list)
    remaining = copy.deepcopy(input_items)
    output = []
    while remaining:
        # ascending sort, so the longest list is popped off the end
        remaining = sort_by_length(remaining)
        longest = remaining.pop()

        if len(longest) == 0:
            continue

        if len(longest) >= batch_size:
            output.append(longest[0:batch_size])
            del longest[0:batch_size]
            if len(longest) > 0:
                remaining.append(longest)
        else:
            # need to build this chunk by combining multiple
            combined = copy.deepcopy(longest)
            while True:
                fill_length = batch_size - len(combined)
                if fill_length == 0:
                    break
                if len(remaining) == 0 and len(filler_items) == 0:
                    break

                from_filler_bucket = filler_items[0:fill_length]
                if len(from_filler_bucket) > 0:
                    del filler_items[0:fill_length]
                    combined.extend(from_filler_bucket)
                    continue

                filler = next((r for r in remaining if len(r) <= fill_length), None)
                if filler is not None:
                    remaining.remove(filler)
                    combined.extend(filler)
                else:
                    # steal from the next longest
                    next_longest = remaining.pop()
                    combined.extend(next_longest[0:fill_length])
                    del next_longest[0:fill_length]
                    if len(next_longest) > 0:
                        remaining.append(next_longest)
            output.append(combined)

    # any unused filler items trail at the end
    output.append(filler_items)
    return [i for o in output for i in o]
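Note how the data loader (first file in this commit) calls this function: the "items" it passes are not single images but whole chunks of batch_size images, and batch_size here is bound to grad_accum, so packing chunks into batches means packing same-named batches into shared gradient-accumulation runs. A minimal sketch of that call shape (toy strings in place of image chunks):

    from utils.first_fit_decreasing import first_fit_decreasing

    # each inner list stands in for a chunk of batch_size images from one
    # aspect-ratio bucket; grad_accum=2 chunks make up one accumulation run
    named_batch_a = [["A1", "A2"], ["A3", "A4"]]    # pre-chunked named batch
    default_chunks = [["d1", "d2"], ["d3", "d4"]]   # DEFAULT_BATCH_ID chunks, used as filler

    packed = first_fit_decreasing([named_batch_a], batch_size=2, filler_items=default_chunks)
    # [['A1', 'A2'], ['A3', 'A4'], ['d1', 'd2'], ['d3', 'd4']]
    # the named batch fills one whole grad-accum run; unused filler trails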