improve shuffle and runt handling for named buckets

This commit is contained in:
Damian Stewart 2023-06-14 09:18:18 +02:00
parent dadf881f9a
commit e4872fdc0c
1 changed file with 25 additions and 2 deletions

View File

@ -117,8 +117,19 @@ class DataLoaderMultiAspect():
buckets[bucket_key] = []
buckets[bucket_key].append(image_caption_pair)
# handle runts by randomly duplicating items
# handle named batch runts by demoting them to the DEFAULT_BATCH_ID
for key, bucket in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]:
runt_count = len(bucket) % batch_size
if runt_count == 0:
continue
runts = bucket[-runt_count:]
del bucket[-runt_count:]
matching_default_bucket_key = [DEFAULT_BATCH_ID, key[1], key[2]]
buckets[matching_default_bucket_key].extend(runts)
# handle remaining runts by randomly duplicating items
for bucket in buckets:
assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches"
truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0:
runt_bucket = buckets[bucket][-truncate_count:]
@ -132,9 +143,21 @@ class DataLoaderMultiAspect():
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
buckets[bucket].extend(runt_bucket)
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
# at this point items have a partially deterministic order
# (in particular: rarer aspect ratios are more likely to cluster at the end due to stochastic sampling)
# so we shuffle them to mitigate this, using chunked_shuffle to keep batches with the same aspect ratio together
items_by_batch_id = {k: chunked_shuffle(v, chunk_size=batch_size, randomizer=randomizer)
for k,v in items_by_batch_id.items()}
# paranoia: verify that this hasn't fucked up the aspect ratio batching
for items in items_by_batch_id.values():
batches = chunk(items, chunk_size=batch_size)
for batch in batches:
target_wh = batch[0].target_wh
assert all(target_wh == i.target_wh for i in batch[1:]), "mixed aspect ratios in a batch - this shouldn't happen"
# handle batch_id
# unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
batch_size=batch_size,
grad_accum=grad_accum)