assert an important precondition

This commit is contained in:
Damian Stewart 2023-06-08 11:01:16 +02:00
parent 4f98f0bcc9
commit a047294676
1 changed files with 4 additions and 2 deletions

View File

@ -198,7 +198,6 @@ def unchunk(chunked_list: List):
return [i for c in chunked_list for i in c]
def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
# interleave buckets while trying to maximise shared grad accum chunks
batch_ids = [k[0] for k in buckets.keys()]
items_by_batch_id = {}
for batch_id in batch_ids:
@ -208,7 +207,10 @@ def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]],
batch_size: int,
grad_accum: int) -> List[ImageTrainItem]:
# ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
# precondition: items_by_batch_id has no incomplete batches
assert(all((len(v) % batch_size)==0 for v in items_by_batch_id.values()))
# ensure we don't mix up aspect ratios by treating each chunk of batch_size images as
# a single unit to pass to first_fit_decreasing()
filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,