diff --git a/data/data_loader.py b/data/data_loader.py
index 8ad62a0..1ab9b67 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -198,7 +198,6 @@ def unchunk(chunked_list: List):
     return [i for c in chunked_list for i in c]
 
 def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
-    # interleave buckets while trying to maximise shared grad accum chunks
    batch_ids = [k[0] for k in buckets.keys()]
     items_by_batch_id = {}
     for batch_id in batch_ids:
@@ -208,7 +207,10 @@ def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
 
 def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]], batch_size: int, grad_accum: int) -> List[ImageTrainItem]:
-    # ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
+    # precondition: items_by_batch_id has no incomplete batches
+    assert(all((len(v) % batch_size)==0 for v in items_by_batch_id.values()))
+    # ensure we don't mix up aspect ratios by treating each chunk of batch_size images as
+    # a single unit to pass to first_fit_decreasing()
     filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
     custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
     neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,
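For context, a minimal sketch of the chunk-as-unit trick the new comments describe: each run of `batch_size` same-bucket items is treated as one indivisible unit, so whatever packing `first_fit_decreasing()` performs can never split a batch across aspect ratios. Only `unchunk()` appears verbatim in the diff; the `chunk()` helper, the `DEFAULT_BATCH_ID` value, and the toy string items (stand-ins for `ImageTrainItem` objects) are assumptions for illustration.

```python
from typing import Dict, List, TypeVar

T = TypeVar("T")
DEFAULT_BATCH_ID = "default"  # assumed value of the sentinel used in the diff

def chunk(l: List[T], chunk_size: int) -> List[List[T]]:
    # assumed counterpart of unchunk(): split l into consecutive
    # chunk_size-sized pieces (the diff's new assert guarantees no remainder)
    return [l[i:i + chunk_size] for i in range(0, len(l), chunk_size)]

def unchunk(chunked_list: List[List[T]]) -> List[T]:
    # as in the diff: flatten one level of nesting
    return [i for c in chunked_list for i in c]

# toy stand-ins for ImageTrainItem lists, grouped by batch id
items_by_batch_id: Dict[str, List[str]] = {
    DEFAULT_BATCH_ID: ["d1", "d2", "d3", "d4"],
    "portraits_2x3": ["p1", "p2"],
}
batch_size = 2

# the precondition asserted in the diff: every list chunks evenly
assert all(len(v) % batch_size == 0 for v in items_by_batch_id.values())

# each inner list is now an opaque unit for the packing step
filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
custom_batched_items = [chunk(v, batch_size)
                        for k, v in items_by_batch_id.items()
                        if k != DEFAULT_BATCH_ID]

print(filler_items)          # [['d1', 'd2'], ['d3', 'd4']]
print(custom_batched_items)  # [[['p1', 'p2']]]
```

Because `first_fit_decreasing()` presumably only rearranges whole chunks (using the default-batch chunks as filler), flattening its output with `unchunk()` yields an item list in which every consecutive `batch_size` window still comes from a single bucket.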