improve shuffle and runt handling for named buckets
This commit is contained in:
parent
dadf881f9a
commit
e4872fdc0c
|
@ -117,8 +117,19 @@ class DataLoaderMultiAspect():
|
||||||
buckets[bucket_key] = []
|
buckets[bucket_key] = []
|
||||||
buckets[bucket_key].append(image_caption_pair)
|
buckets[bucket_key].append(image_caption_pair)
|
||||||
|
|
||||||
# handle runts by randomly duplicating items
|
# handle named batch runts by demoting them to the DEFAULT_BATCH_ID
|
||||||
|
for key, bucket in [(k, b) for k, b in buckets.items() if k[0] != DEFAULT_BATCH_ID]:
|
||||||
|
runt_count = len(bucket) % batch_size
|
||||||
|
if runt_count == 0:
|
||||||
|
continue
|
||||||
|
runts = bucket[-runt_count:]
|
||||||
|
del bucket[-runt_count:]
|
||||||
|
matching_default_bucket_key = [DEFAULT_BATCH_ID, key[1], key[2]]
|
||||||
|
buckets[matching_default_bucket_key].extend(runts)
|
||||||
|
|
||||||
|
# handle remaining runts by randomly duplicating items
|
||||||
for bucket in buckets:
|
for bucket in buckets:
|
||||||
|
assert bucket[0] == DEFAULT_BATCH_ID, "there should be no more runts in named batches"
|
||||||
truncate_count = len(buckets[bucket]) % batch_size
|
truncate_count = len(buckets[bucket]) % batch_size
|
||||||
if truncate_count > 0:
|
if truncate_count > 0:
|
||||||
runt_bucket = buckets[bucket][-truncate_count:]
|
runt_bucket = buckets[bucket][-truncate_count:]
|
||||||
|
@ -132,9 +143,21 @@ class DataLoaderMultiAspect():
|
||||||
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
||||||
buckets[bucket].extend(runt_bucket)
|
buckets[bucket].extend(runt_bucket)
|
||||||
|
|
||||||
|
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
|
||||||
|
# at this point items have a partially deterministic order
|
||||||
|
# (in particular: rarer aspect ratios are more likely to cluster at the end due to stochastic sampling)
|
||||||
|
# so we shuffle them to mitigate this, using chunked_shuffle to keep batches with the same aspect ratio together
|
||||||
|
items_by_batch_id = {k: chunked_shuffle(v, chunk_size=batch_size, randomizer=randomizer)
|
||||||
|
for k,v in items_by_batch_id.items()}
|
||||||
|
# paranoia: verify that the shuffle hasn't broken the aspect ratio batching
|
||||||
|
for items in items_by_batch_id.values():
|
||||||
|
batches = chunk(items, chunk_size=batch_size)
|
||||||
|
for batch in batches:
|
||||||
|
target_wh = batch[0].target_wh
|
||||||
|
assert all(target_wh == i.target_wh for i in batch[1:]), "mixed aspect ratios in a batch - this shouldn't happen"
|
||||||
|
|
||||||
# handle batch_id
|
# handle batch_id
|
||||||
# unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
|
# unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
|
||||||
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
|
|
||||||
items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
|
items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
grad_accum=grad_accum)
|
grad_accum=grad_accum)
|
||||||
|
|
Loading…
Reference in New Issue