assert an important precondition
This commit is contained in:
parent
4f98f0bcc9
commit
a047294676
|
@ -198,7 +198,6 @@ def unchunk(chunked_list: List):
|
|||
return [i for c in chunked_list for i in c]
|
||||
|
||||
def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
|
||||
# interleave buckets while trying to maximise shared grad accum chunks
|
||||
batch_ids = [k[0] for k in buckets.keys()]
|
||||
items_by_batch_id = {}
|
||||
for batch_id in batch_ids:
|
||||
|
@ -208,7 +207,10 @@ def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
|
|||
def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]],
|
||||
batch_size: int,
|
||||
grad_accum: int) -> List[ImageTrainItem]:
|
||||
# ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
|
||||
# precondition: items_by_batch_id has no incomplete batches
|
||||
assert(all((len(v) % batch_size)==0 for v in items_by_batch_id.values()))
|
||||
# ensure we don't mix up aspect ratios by treating each chunk of batch_size images as
|
||||
# a single unit to pass to first_fit_decreasing()
|
||||
filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
|
||||
custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
|
||||
neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,
|
||||
|
|
Loading…
Reference in New Issue