ensure predictable shuffle behaviour and further cleanup
This commit is contained in:
parent
1874a38663
commit
4f98f0bcc9
|
@ -139,18 +139,8 @@ class DataLoaderMultiAspect():
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
grad_accum=grad_accum)
|
grad_accum=grad_accum)
|
||||||
|
|
||||||
# chunk by effective batch size
|
|
||||||
effective_batch_size = batch_size * grad_accum
|
effective_batch_size = batch_size * grad_accum
|
||||||
chunks = chunk(items, effective_batch_size)
|
items = chunked_shuffle(items, chunk_size=effective_batch_size, randomizer=randomizer)
|
||||||
# shuffle, but preserve the last chunk as last if it is incomplete
|
|
||||||
last_chunk = None
|
|
||||||
if len(chunks[-1]) < effective_batch_size:
|
|
||||||
last_chunk = chunks.pop(-1)
|
|
||||||
random.shuffle(chunks)
|
|
||||||
if last_chunk is not None:
|
|
||||||
chunks.append(last_chunk)
|
|
||||||
# un-chunk
|
|
||||||
items = unchunk(chunks)
|
|
||||||
|
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
@ -227,3 +217,21 @@ def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str
|
||||||
|
|
||||||
items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
|
items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List:
|
||||||
|
"""
|
||||||
|
Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
|
||||||
|
If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# chunk by effective batch size
|
||||||
|
chunks = chunk(l, chunk_size)
|
||||||
|
# preserve last chunk as last if it is incomplete
|
||||||
|
last_chunk = None
|
||||||
|
if len(chunks[-1]) < chunk_size:
|
||||||
|
last_chunk = chunks.pop(-1)
|
||||||
|
randomizer.shuffle(chunks)
|
||||||
|
if last_chunk is not None:
|
||||||
|
chunks.append(last_chunk)
|
||||||
|
l = unchunk(chunks)
|
||||||
|
return l
|
||||||
|
|
Loading…
Reference in New Issue