This commit is contained in:
Damian Stewart 2023-06-08 10:39:32 +02:00
parent b3c5d656e3
commit 1874a38663
1 changed files with 36 additions and 41 deletions

View File

@ -15,14 +15,11 @@ limitations under the License.
"""
import bisect
import logging
import os.path
from collections import defaultdict
import math
import copy
import random
from itertools import groupby
from typing import Tuple, List, Dict
from typing import List, Dict
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
import PIL.Image
@ -31,37 +28,6 @@ from utils.first_fit_decreasing import first_fit_decreasing
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
def __chunk(l: List, chunk_size) -> List:
num_chunks = int(math.ceil(float(len(l)) / chunk_size))
return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]
def __unchunk(chunked_list: List):
return [i for c in chunked_list for i in c]
def __collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
# interleave buckets while trying to maximise shared grad accum chunks
batch_ids = [k[0] for k in buckets.keys()]
items_by_batch_id = {}
for batch_id in batch_ids:
items_by_batch_id[batch_id] = __unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
return items_by_batch_id
def __flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]],
batch_size: int,
grad_accum: int) -> List[ImageTrainItem]:
# ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
filler_items = __chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
custom_batched_items = [__chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,
batch_size=grad_accum,
filler_items=filler_items)
items: List[ImageTrainItem] = __unchunk(neighbourly_chunked_items)
return items
class DataLoaderMultiAspect():
"""
Data loader for multi-aspect-ratio training and bucketing
@ -144,8 +110,9 @@ class DataLoaderMultiAspect():
for image_caption_pair in picked_images:
image_caption_pair.runt_size = 0
batch_id = image_caption_pair.batch_id
bucket_key = (batch_id, image_caption_pair.target_wh[0], image_caption_pair.target_wh[1])
bucket_key = (image_caption_pair.batch_id,
image_caption_pair.target_wh[0],
image_caption_pair.target_wh[1])
if bucket_key not in buckets:
buckets[bucket_key] = []
buckets[bucket_key].append(image_caption_pair)
@ -167,14 +134,14 @@ class DataLoaderMultiAspect():
# handle batch_id
# unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
items_by_batch_id = __collapse_buckets_by_batch_id(buckets)
items = __flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
batch_size=batch_size,
grad_accum=grad_accum)
# chunk by effective batch size
effective_batch_size = batch_size * grad_accum
chunks = __chunk(items, effective_batch_size)
chunks = chunk(items, effective_batch_size)
# shuffle, but preserve the last chunk as last if it is incomplete
last_chunk = None
if len(chunks[-1]) < effective_batch_size:
@ -183,7 +150,7 @@ class DataLoaderMultiAspect():
if last_chunk is not None:
chunks.append(last_chunk)
# un-chunk
items = __unchunk(chunks)
items = unchunk(chunks)
return items
@ -232,3 +199,31 @@ class DataLoaderMultiAspect():
self.rating_overall_sum += item.caption.rating()
self.ratings_summed.append(self.rating_overall_sum)
def chunk(l: List, chunk_size) -> List:
num_chunks = int(math.ceil(float(len(l)) / chunk_size))
return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]
def unchunk(chunked_list: List):
return [i for c in chunked_list for i in c]
def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
# interleave buckets while trying to maximise shared grad accum chunks
batch_ids = [k[0] for k in buckets.keys()]
items_by_batch_id = {}
for batch_id in batch_ids:
items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
return items_by_batch_id
def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]],
batch_size: int,
grad_accum: int) -> List[ImageTrainItem]:
# ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,
batch_size=grad_accum,
filler_items=filler_items)
items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
return items