shuffle named batches while respecting and accounting for grad_accum
This commit is contained in:
parent
a8455b9427
commit
59fc9891d4
|
@ -27,6 +27,8 @@ from typing import Tuple, List
|
|||
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
|
||||
import PIL.Image
|
||||
|
||||
from utils.first_fit_decreasing import first_fit_decreasing
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
||||
|
||||
class DataLoaderMultiAspect():
|
||||
|
@ -37,9 +39,10 @@ class DataLoaderMultiAspect():
|
|||
seed: random seed
|
||||
batch_size: number of images per batch
|
||||
"""
|
||||
def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
|
||||
def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1, grad_accum=1):
|
||||
self.seed = seed
|
||||
self.batch_size = batch_size
|
||||
self.grad_accum = grad_accum
|
||||
self.prepared_train_data = image_train_items
|
||||
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
||||
|
@ -106,6 +109,7 @@ class DataLoaderMultiAspect():
|
|||
|
||||
buckets = {}
|
||||
batch_size = self.batch_size
|
||||
grad_accum = self.grad_accum
|
||||
|
||||
for image_caption_pair in picked_images:
|
||||
image_caption_pair.runt_size = 0
|
||||
|
@ -115,7 +119,7 @@ class DataLoaderMultiAspect():
|
|||
buckets[bucket_key] = []
|
||||
buckets[bucket_key].append(image_caption_pair)
|
||||
|
||||
# handle runts in "general" buckets by randomly duplicating items
|
||||
# handle runts by randomly duplicating items
|
||||
for bucket in buckets:
|
||||
truncate_count = len(buckets[bucket]) % batch_size
|
||||
if truncate_count > 0:
|
||||
|
@ -130,10 +134,38 @@ class DataLoaderMultiAspect():
|
|||
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
||||
buckets[bucket].extend(runt_bucket)
|
||||
|
||||
# flatten the buckets
|
||||
items: list[ImageTrainItem] = []
|
||||
for bucket in buckets:
|
||||
items.extend(buckets[bucket])
|
||||
def chunk(l: List, chunk_size) -> List:
|
||||
num_chunks = int(math.ceil(float(len(l)) / chunk_size))
|
||||
return [l[i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)]
|
||||
|
||||
def unchunk(chunked_list: List):
|
||||
return [i for c in chunked_list for i in c]
|
||||
|
||||
# interleave buckets while trying to maximise shared grad accum chunks
|
||||
batch_ids = [k[0] for k in buckets.keys()]
|
||||
items_by_batch_id = {}
|
||||
for batch_id in batch_ids:
|
||||
items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
|
||||
# ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
|
||||
filler_items = chunk(items_by_batch_id[DEFAULT_BATCH_ID], batch_size)
|
||||
custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
|
||||
#custom_batched_items = chunk([b for bucket_key,b in buckets.items() if bucket_key[0] != DEFAULT_BATCH_ID], batch_size)
|
||||
neighbourly_chunked_items = first_fit_decreasing(custom_batched_items, batch_size=grad_accum, filler_items=filler_items)
|
||||
|
||||
items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
|
||||
|
||||
# chunk by effective batch size
|
||||
effective_batch_size = batch_size * grad_accum
|
||||
chunks = chunk(items, effective_batch_size)
|
||||
# shuffle, but preserve the last chunk as last if it is incomplete
|
||||
last_chunk = None
|
||||
if len(chunks[-1]) < effective_batch_size:
|
||||
last_chunk = chunks.pop(-1)
|
||||
random.shuffle(chunks)
|
||||
if last_chunk is not None:
|
||||
chunks.append(last_chunk)
|
||||
# un-chunk
|
||||
items = unchunk(chunks)
|
||||
|
||||
return items
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from utils.first_fit_decreasing import first_fit_decreasing
|
|||
|
||||
class TestFirstFitDecreasing(unittest.TestCase):
|
||||
|
||||
def test_basic(self):
|
||||
def test_single_basic(self):
|
||||
input = [[1, 2, 3, 4, 5, 6]]
|
||||
output = first_fit_decreasing(input, batch_size=2)
|
||||
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
|
||||
|
@ -20,3 +20,62 @@ class TestFirstFitDecreasing(unittest.TestCase):
|
|||
input = [[1, 2, 3]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [1, 2, 3])
|
||||
|
||||
def test_multi_basic(self):
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=2)
|
||||
self.assertEqual(output, [1, 1, 1, 1, 2, 2])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=3)
|
||||
self.assertEqual(output, [1, 1, 1, 2, 2, 1])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [1, 1, 1, 1, 2, 2])
|
||||
|
||||
input = [[1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [2, 2, 1, 1])
|
||||
|
||||
def test_multi_complex(self):
|
||||
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
|
||||
output = first_fit_decreasing(input, batch_size=2)
|
||||
self.assertEqual(output, [1, 1, 3, 3, 1, 1, 2, 2, 3])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
|
||||
output = first_fit_decreasing(input, batch_size=3)
|
||||
self.assertEqual(output, [1, 1, 1, 3, 3, 3, 2, 2, 1])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [1, 1, 1, 1, 3, 3, 3, 2, 2])
|
||||
|
||||
input = [[1, 1], [2, 2], [3, 3, 3]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [3, 3, 3, 2, 1, 1, 2])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 3, 2, 2, 2, 3, 5, 5, 3])
|
||||
|
||||
def test_filler_bucket(self):
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=2, filler_items=[9, 9])
|
||||
self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=3, filler_items=[9, 9])
|
||||
self.assertEqual(output, [1, 1, 1, 2, 2, 9, 1, 9])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
|
||||
self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9])
|
||||
|
||||
input = [[1, 1], [2, 2]]
|
||||
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
|
||||
self.assertEqual(output, [2, 2, 9, 9, 1, 1])
|
||||
|
||||
input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]]
|
||||
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
|
||||
self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 9, 3, 3, 3, 9, 2, 2, 2, 5, 5])
|
||||
|
|
1
train.py
1
train.py
|
@ -552,6 +552,7 @@ def main(args):
|
|||
image_train_items=image_train_items,
|
||||
seed=seed,
|
||||
batch_size=args.batch_size,
|
||||
grad_accum=args.grad_accum
|
||||
)
|
||||
|
||||
train_batch = EveryDreamBatch(
|
||||
|
|
|
@ -1,16 +1,63 @@
|
|||
import copy
|
||||
from typing import List
|
||||
|
||||
def first_fit_decreasing(input_list: List[List], batch_size: int) -> List:
|
||||
def first_fit_decreasing(input_items: List[List], batch_size: int, filler_items: List = None) -> List:
    """
    Given as input a list of lists, batch the items so that as much as possible the members of each of the original
    lists end up in the same batch. Pad out too-short batches by taking items from the filler_items list, if available.

    @param input_items: list of lists of items that should be kept together
    @param batch_size: maximum number of items per batch
    @param filler_items: optional spare items used to pad batches that come up short
    @return flattened list of all items in input_items and filler_items, arranged such that, as much as possible, items
            that are in the same input list end up in the same batch.
    """
    # Fix: the previous signature used a mutable default (filler_items: List=[]) and
    # drained the caller's list in place via `del filler_items[...]` — a classic
    # Python pitfall. Copying up front keeps the function free of caller-visible
    # side effects while producing an identical return value.
    filler_items = [] if filler_items is None else list(filler_items)
    # deep copy so the inner lists of input_items are never mutated either
    remaining = copy.deepcopy(input_items)

    def sort_by_length(items: List[List]) -> List[List]:
        # ascending order, so .pop() always yields the current longest list
        return sorted(items, key=len)

    output = []
    while remaining:
        remaining = sort_by_length(remaining)
        longest = remaining.pop()
        if len(longest) == 0:
            continue
        if len(longest) >= batch_size:
            # take one full batch off the front; requeue any leftover tail
            output.append(longest[0:batch_size])
            del longest[0:batch_size]
            if len(longest) > 0:
                remaining.append(longest)
        else:
            # the longest list alone can't fill a batch: combine multiple sources
            combined = copy.deepcopy(longest)
            while True:
                fill_length = batch_size - len(combined)
                if fill_length == 0:
                    break
                if len(remaining) == 0 and len(filler_items) == 0:
                    break  # nothing left anywhere — emit a short batch

                # prefer filler items so real input lists stay intact
                from_filler_bucket = filler_items[0:fill_length]
                if len(from_filler_bucket) > 0:
                    del filler_items[0:fill_length]
                    combined.extend(from_filler_bucket)
                    continue

                # next, look for a whole input list that fits entirely in the gap
                filler = next((r for r in remaining if len(r) <= fill_length), None)
                if filler is not None:
                    remaining.remove(filler)
                    combined.extend(filler)
                else:
                    # otherwise steal the front of the next-longest list
                    next_longest = remaining.pop()
                    combined.extend(next_longest[0:fill_length])
                    del next_longest[0:fill_length]
                    if len(next_longest) > 0:
                        remaining.append(next_longest)
            output.append(combined)

    # any unused filler items go at the very end
    output.append(filler_items)
    return [i for o in output for i in o]
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue