ensure predictable shuffle behaviour and further cleanup
This commit is contained in:
parent
1874a38663
commit
4f98f0bcc9
|
@ -139,18 +139,8 @@ class DataLoaderMultiAspect():
|
|||
batch_size=batch_size,
|
||||
grad_accum=grad_accum)
|
||||
|
||||
# chunk by effective batch size
|
||||
effective_batch_size = batch_size * grad_accum
|
||||
chunks = chunk(items, effective_batch_size)
|
||||
# shuffle, but preserve the last chunk as last if it is incomplete
|
||||
last_chunk = None
|
||||
if len(chunks[-1]) < effective_batch_size:
|
||||
last_chunk = chunks.pop(-1)
|
||||
random.shuffle(chunks)
|
||||
if last_chunk is not None:
|
||||
chunks.append(last_chunk)
|
||||
# un-chunk
|
||||
items = unchunk(chunks)
|
||||
items = chunked_shuffle(items, chunk_size=effective_batch_size, randomizer=randomizer)
|
||||
|
||||
return items
|
||||
|
||||
|
@ -227,3 +217,21 @@ def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str
|
|||
|
||||
items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
|
||||
return items
|
||||
|
||||
def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List:
|
||||
"""
|
||||
Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
|
||||
If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
|
||||
"""
|
||||
|
||||
# chunk by effective batch size
|
||||
chunks = chunk(l, chunk_size)
|
||||
# preserve last chunk as last if it is incomplete
|
||||
last_chunk = None
|
||||
if len(chunks[-1]) < chunk_size:
|
||||
last_chunk = chunks.pop(-1)
|
||||
randomizer.shuffle(chunks)
|
||||
if last_chunk is not None:
|
||||
chunks.append(last_chunk)
|
||||
l = unchunk(chunks)
|
||||
return l
|
||||
|
|
Loading…
Reference in New Issue