allow empty default_batch

This commit is contained in:
damian 2023-06-07 18:39:13 +02:00
parent 59fc9891d4
commit 86f80a8776
1 changed files with 1 additions and 1 deletions

View File

@ -147,7 +147,7 @@ class DataLoaderMultiAspect():
for batch_id in batch_ids: for batch_id in batch_ids:
items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id]) items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
# ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing # ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
filler_items = chunk(items_by_batch_id[DEFAULT_BATCH_ID], batch_size) filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID] custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
#custom_batched_items = chunk([b for bucket_key,b in buckets.items() if bucket_key[0] != DEFAULT_BATCH_ID], batch_size) #custom_batched_items = chunk([b for bucket_key,b in buckets.items() if bucket_key[0] != DEFAULT_BATCH_ID], batch_size)
neighbourly_chunked_items = first_fit_decreasing(custom_batched_items, batch_size=grad_accum, filler_items=filler_items) neighbourly_chunked_items = first_fit_decreasing(custom_batched_items, batch_size=grad_accum, filler_items=filler_items)