From e4872fdc0c2a903acfa6b9bad23bcd887e10afb6 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 14 Jun 2023 09:18:18 +0200 Subject: [PATCH] improve shuffle and runt handling for named buckets --- data/data_loader.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/data/data_loader.py b/data/data_loader.py index 1ab9b67..12706e8 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -117,8 +117,19 @@ class DataLoaderMultiAspect(): buckets[bucket_key] = [] buckets[bucket_key].append(image_caption_pair) - # handle runts by randomly duplicating items + # handled named batch runts by demoting them to the DEFAULT_BATCH_ID + for key, bucket in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]: + runt_count = len(bucket) % batch_size + if runt_count == 0: + continue + runts = bucket[-runt_count:] + del bucket[-runt_count:] + matching_default_bucket_key = [DEFAULT_BATCH_ID, key[1], key[2]] + buckets[matching_default_bucket_key].extend(runts) + + # handle remaining runts by randomly duplicating items for bucket in buckets: + assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches" truncate_count = len(buckets[bucket]) % batch_size if truncate_count > 0: runt_bucket = buckets[bucket][-truncate_count:] @@ -132,9 +143,21 @@ class DataLoaderMultiAspect(): buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count] buckets[bucket].extend(runt_bucket) + items_by_batch_id = collapse_buckets_by_batch_id(buckets) + # at this point items have a partially deterministic order + # (in particular: rarer aspect ratios are more likely to cluster at the end due to stochastic sampling) + # so we shuffle them to mitigate this, using chunked_shuffle to keep batches with the same aspect ratio together + items_by_batch_id = {k: chunked_shuffle(v, chunk_size=batch_size, randomizer=randomizer) + for k,v in items_by_batch_id.items()} + # paranoia: verify that this hasn't fucked up the aspect ratio batching + for items in items_by_batch_id.values(): + batches = chunk(items, chunk_size=batch_size) + for batch in batches: + target_wh = batch[0].target_wh + assert all(target_wh == i.target_wh for i in batch[1:]), "mixed aspect ratios in a batch - this shouldn't happen" + # handle batch_id # unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID. - items_by_batch_id = collapse_buckets_by_batch_id(buckets) items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id, batch_size=batch_size, grad_accum=grad_accum)