From 4f98f0bcc98fe95908bdcaab796404a81055f794 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Thu, 8 Jun 2023 10:50:51 +0200
Subject: [PATCH] ensure predictable shuffle behaviour and further cleanup

---
 data/data_loader.py | 30 +++++++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/data/data_loader.py b/data/data_loader.py
index a42f28a..8ad62a0 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -139,18 +139,8 @@ class DataLoaderMultiAspect():
                                                                  batch_size=batch_size,
                                                                  grad_accum=grad_accum)

-        # chunk by effective batch size
         effective_batch_size = batch_size * grad_accum
-        chunks = chunk(items, effective_batch_size)
-        # shuffle, but preserve the last chunk as last if it is incomplete
-        last_chunk = None
-        if len(chunks[-1]) < effective_batch_size:
-            last_chunk = chunks.pop(-1)
-        random.shuffle(chunks)
-        if last_chunk is not None:
-            chunks.append(last_chunk)
-        # un-chunk
-        items = unchunk(chunks)
+        items = chunked_shuffle(items, chunk_size=effective_batch_size, randomizer=randomizer)

         return items

@@ -227,3 +217,21 @@ def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str

     items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
     return items
+
+def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List:
+    """
+    Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
+    If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
+    """
+
+    # chunk by effective batch size
+    chunks = chunk(l, chunk_size)
+    # preserve last chunk as last if it is incomplete
+    last_chunk = None
+    if len(chunks[-1]) < chunk_size:
+        last_chunk = chunks.pop(-1)
+    randomizer.shuffle(chunks)
+    if last_chunk is not None:
+        chunks.append(last_chunk)
+    l = unchunk(chunks)
+    return l