wip first fit decreasing for batch/grad accum shuffling

This commit is contained in:
Damian Stewart 2023-06-07 13:49:06 +02:00
parent ba95b8c6d1
commit a8455b9427
2 changed files with 38 additions and 0 deletions

View File

@ -0,0 +1,22 @@
import unittest
from utils.first_fit_decreasing import first_fit_decreasing
class TestFirstFitDecreasing(unittest.TestCase):
def test_basic(self):
input = [[1, 2, 3, 4, 5, 6]]
output = first_fit_decreasing(input, batch_size=2)
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
input = [[1, 2, 3, 4, 5, 6]]
output = first_fit_decreasing(input, batch_size=3)
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
input = [[1, 2, 3, 4, 5, 6]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
input = [[1, 2, 3]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 2, 3])

View File

@ -0,0 +1,16 @@
from typing import List
def first_fit_decreasing(input_list: List[List], batch_size: int) -> List:
"""
Given as input a list of lists, batch the items so that as much as possible the members of each of the original
lists end up in the same batch.
@return a list of batches
"""
def sort_by_length(items: List[List]):
return items.sort(key=lambda x: len(x), reverse=True)
remaining = list(input_list)
while remaining:
remaining = sort_by_length(remaining)