allow empty default_batch
This commit is contained in:
parent
59fc9891d4
commit
86f80a8776
|
@ -147,7 +147,7 @@ class DataLoaderMultiAspect():
|
||||||
for batch_id in batch_ids:
|
for batch_id in batch_ids:
|
||||||
items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
|
items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
|
||||||
# ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
|
# ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
|
||||||
filler_items = chunk(items_by_batch_id[DEFAULT_BATCH_ID], batch_size)
|
filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
|
||||||
custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
|
custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
|
||||||
#custom_batched_items = chunk([b for bucket_key,b in buckets.items() if bucket_key[0] != DEFAULT_BATCH_ID], batch_size)
|
#custom_batched_items = chunk([b for bucket_key,b in buckets.items() if bucket_key[0] != DEFAULT_BATCH_ID], batch_size)
|
||||||
neighbourly_chunked_items = first_fit_decreasing(custom_batched_items, batch_size=grad_accum, filler_items=filler_items)
|
neighbourly_chunked_items = first_fit_decreasing(custom_batched_items, batch_size=grad_accum, filler_items=filler_items)
|
||||||
|
|
Loading…
Reference in New Issue