improve shuffle and runt handling for named buckets
This commit is contained in:
parent
dadf881f9a
commit
e4872fdc0c
|
@ -117,8 +117,19 @@ class DataLoaderMultiAspect():
|
|||
buckets[bucket_key] = []
|
||||
buckets[bucket_key].append(image_caption_pair)
|
||||
|
||||
# handle runts by randomly duplicating items
|
||||
# handled named batch runts by demoting them to the DEFAULT_BATCH_ID
|
||||
for key, bucket in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]:
|
||||
runt_count = len(bucket) % batch_size
|
||||
if runt_count == 0:
|
||||
continue
|
||||
runts = bucket[-runt_count:]
|
||||
del bucket[-runt_count:]
|
||||
matching_default_bucket_key = [DEFAULT_BATCH_ID, key[1], key[2]]
|
||||
buckets[matching_default_bucket_key].extend(runts)
|
||||
|
||||
# handle remaining runts by randomly duplicating items
|
||||
for bucket in buckets:
|
||||
assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches"
|
||||
truncate_count = len(buckets[bucket]) % batch_size
|
||||
if truncate_count > 0:
|
||||
runt_bucket = buckets[bucket][-truncate_count:]
|
||||
|
@ -132,9 +143,21 @@ class DataLoaderMultiAspect():
|
|||
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
||||
buckets[bucket].extend(runt_bucket)
|
||||
|
||||
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
|
||||
# at this point items have a partially deterministic order
|
||||
# (in particular: rarer aspect ratios are more likely to cluster at the end due to stochastic sampling)
|
||||
# so we shuffle them to mitigate this, using chunked_shuffle to keep batches with the same aspect ratio together
|
||||
items_by_batch_id = {k: chunked_shuffle(v, chunk_size=batch_size, randomizer=randomizer)
|
||||
for k,v in items_by_batch_id.items()}
|
||||
# paranoia: verify that this hasn't fucked up the aspect ratio batching
|
||||
for items in items_by_batch_id.values():
|
||||
batches = chunk(items, chunk_size=batch_size)
|
||||
for batch in batches:
|
||||
target_wh = batch[0].target_wh
|
||||
assert all(target_wh == i.target_wh for i in batch[1:]), "mixed aspect ratios in a batch - this shouldn't happen"
|
||||
|
||||
# handle batch_id
|
||||
# unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
|
||||
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
|
||||
items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
|
||||
batch_size=batch_size,
|
||||
grad_accum=grad_accum)
|
||||
|
|
Loading…
Reference in New Issue