wip first fit decreasing for batch/grad accum shuffling
This commit is contained in:
parent
ba95b8c6d1
commit
a8455b9427
|
@ -0,0 +1,22 @@
|
|||
import unittest
|
||||
|
||||
from utils.first_fit_decreasing import first_fit_decreasing
|
||||
|
||||
class TestFirstFitDecreasing(unittest.TestCase):
|
||||
|
||||
def test_basic(self):
|
||||
input = [[1, 2, 3, 4, 5, 6]]
|
||||
output = first_fit_decreasing(input, batch_size=2)
|
||||
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
|
||||
|
||||
input = [[1, 2, 3, 4, 5, 6]]
|
||||
output = first_fit_decreasing(input, batch_size=3)
|
||||
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
|
||||
|
||||
input = [[1, 2, 3, 4, 5, 6]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
|
||||
|
||||
input = [[1, 2, 3]]
|
||||
output = first_fit_decreasing(input, batch_size=4)
|
||||
self.assertEqual(output, [1, 2, 3])
|
|
@ -0,0 +1,16 @@
|
|||
from typing import List
|
||||
|
||||
def first_fit_decreasing(input_list: List[List], batch_size: int) -> List:
|
||||
"""
|
||||
Given as input a list of lists, batch the items so that as much as possible the members of each of the original
|
||||
lists end up in the same batch.
|
||||
|
||||
@return a list of batches
|
||||
"""
|
||||
|
||||
def sort_by_length(items: List[List]):
|
||||
return items.sort(key=lambda x: len(x), reverse=True)
|
||||
|
||||
remaining = list(input_list)
|
||||
while remaining:
|
||||
remaining = sort_by_length(remaining)
|
Loading…
Reference in New Issue