ensure predictable shuffle behaviour and further cleanup

This commit is contained in:
Damian Stewart 2023-06-08 10:50:51 +02:00
parent 1874a38663
commit 4f98f0bcc9
1 changed files with 19 additions and 11 deletions

View File

@ -139,18 +139,8 @@ class DataLoaderMultiAspect():
batch_size=batch_size, batch_size=batch_size,
grad_accum=grad_accum) grad_accum=grad_accum)
# chunk by effective batch size
effective_batch_size = batch_size * grad_accum effective_batch_size = batch_size * grad_accum
chunks = chunk(items, effective_batch_size) items = chunked_shuffle(items, chunk_size=effective_batch_size, randomizer=randomizer)
# shuffle, but preserve the last chunk as last if it is incomplete
last_chunk = None
if len(chunks[-1]) < effective_batch_size:
last_chunk = chunks.pop(-1)
random.shuffle(chunks)
if last_chunk is not None:
chunks.append(last_chunk)
# un-chunk
items = unchunk(chunks)
return items return items
@ -227,3 +217,21 @@ def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str
items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items) items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
return items return items
def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List:
"""
Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
"""
# chunk by effective batch size
chunks = chunk(l, chunk_size)
# preserve last chunk as last if it is incomplete
last_chunk = None
if len(chunks[-1]) < chunk_size:
last_chunk = chunks.pop(-1)
randomizer.shuffle(chunks)
if last_chunk is not None:
chunks.append(last_chunk)
l = unchunk(chunks)
return l