diff --git a/data/data_loader.py b/data/data_loader.py index 9326926..e6b6ae7 100644 --- a/data/data_loader.py +++ b/data/data_loader.py @@ -147,7 +147,7 @@ class DataLoaderMultiAspect(): for batch_id in batch_ids: items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id]) # ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing - filler_items = chunk(items_by_batch_id[DEFAULT_BATCH_ID], batch_size) + filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size) custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID] #custom_batched_items = chunk([b for bucket_key,b in buckets.items() if bucket_key[0] != DEFAULT_BATCH_ID], batch_size) neighbourly_chunked_items = first_fit_decreasing(custom_batched_items, batch_size=grad_accum, filler_items=filler_items)