diff --git a/data/data_loader.py b/data/data_loader.py
index 2ee0921..a42f28a 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -15,14 +15,11 @@ limitations under the License.
 """
 import bisect
 import logging
-import os.path
-from collections import defaultdict
 import math
 import copy
 
 import random
-from itertools import groupby
-from typing import Tuple, List, Dict
+from typing import List, Dict
 
 from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
 import PIL.Image
@@ -31,37 +28,6 @@ from utils.first_fit_decreasing import first_fit_decreasing
 
 PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
 
-def __chunk(l: List, chunk_size) -> List:
-    num_chunks = int(math.ceil(float(len(l)) / chunk_size))
-    return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]
-
-def __unchunk(chunked_list: List):
-    return [i for c in chunked_list for i in c]
-
-
-
-
-def __collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
-    # interleave buckets while trying to maximise shared grad accum chunks
-    batch_ids = [k[0] for k in buckets.keys()]
-    items_by_batch_id = {}
-    for batch_id in batch_ids:
-        items_by_batch_id[batch_id] = __unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
-    return items_by_batch_id
-
-def __flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]],
-                                                       batch_size: int,
-                                                       grad_accum: int) -> List[ImageTrainItem]:
-    # ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
-    filler_items = __chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
-    custom_batched_items = [__chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
-    neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,
-                                                     batch_size=grad_accum,
-                                                     filler_items=filler_items)
-
-    items: List[ImageTrainItem] = __unchunk(neighbourly_chunked_items)
-    return items
-
 class DataLoaderMultiAspect():
     """
     Data loader for multi-aspect-ratio training and bucketing
@@ -144,8 +110,9 @@ class DataLoaderMultiAspect():
 
         for image_caption_pair in picked_images:
             image_caption_pair.runt_size = 0
-            batch_id = image_caption_pair.batch_id
-            bucket_key = (batch_id, image_caption_pair.target_wh[0], image_caption_pair.target_wh[1])
+            bucket_key = (image_caption_pair.batch_id,
+                          image_caption_pair.target_wh[0],
+                          image_caption_pair.target_wh[1])
             if bucket_key not in buckets:
                 buckets[bucket_key] = []
             buckets[bucket_key].append(image_caption_pair)
@@ -167,14 +134,14 @@ class DataLoaderMultiAspect():
 
         # handle batch_id
         # unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
-        items_by_batch_id = __collapse_buckets_by_batch_id(buckets)
-        items = __flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
+        items_by_batch_id = collapse_buckets_by_batch_id(buckets)
+        items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
                                                                    batch_size=batch_size,
                                                                    grad_accum=grad_accum)
 
         # chunk by effective batch size
         effective_batch_size = batch_size * grad_accum
-        chunks = __chunk(items, effective_batch_size)
+        chunks = chunk(items, effective_batch_size)
         # shuffle, but preserve the last chunk as last if it is incomplete
         last_chunk = None
         if len(chunks[-1]) < effective_batch_size:
@@ -183,7 +150,7 @@ class DataLoaderMultiAspect():
         if last_chunk is not None:
             chunks.append(last_chunk)
         # un-chunk
-        items = __unchunk(chunks)
+        items = unchunk(chunks)
 
         return items
 
@@ -232,3 +199,31 @@ class DataLoaderMultiAspect():
             self.rating_overall_sum += item.caption.rating()
             self.ratings_summed.append(self.rating_overall_sum)
 
+
+def chunk(l: List, chunk_size) -> List:
+    num_chunks = int(math.ceil(float(len(l)) / chunk_size))
+    return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]
+
+def unchunk(chunked_list: List):
+    return [i for c in chunked_list for i in c]
+
+def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
+    # interleave buckets while trying to maximise shared grad accum chunks
+    batch_ids = [k[0] for k in buckets.keys()]
+    items_by_batch_id = {}
+    for batch_id in batch_ids:
+        items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
+    return items_by_batch_id
+
+def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]],
+                                                     batch_size: int,
+                                                     grad_accum: int) -> List[ImageTrainItem]:
+    # ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
+    filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
+    custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
+    neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,
+                                                     batch_size=grad_accum,
+                                                     filler_items=filler_items)
+
+    items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
+    return items
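
For reference, a minimal self-contained sketch of the chunk/unchunk round-trip and the runt-preserving shuffle this patch relies on. chunk and unchunk are copied verbatim from the module-level helpers added above; the driver values (items, effective_batch_size) are illustrative assumptions, and the shuffle logic mirrors the loader code in the diff:

    import math
    import random
    from typing import List

    def chunk(l: List, chunk_size) -> List:
        num_chunks = int(math.ceil(float(len(l)) / chunk_size))
        return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]

    def unchunk(chunked_list: List):
        return [i for c in chunked_list for i in c]

    items = list(range(10))
    effective_batch_size = 4                      # batch_size * grad_accum
    chunks = chunk(items, effective_batch_size)   # -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

    # shuffle chunks, but keep an incomplete ("runt") last chunk in last place
    last_chunk = None
    if len(chunks[-1]) < effective_batch_size:
        last_chunk = chunks.pop(-1)
    random.shuffle(chunks)
    if last_chunk is not None:
        chunks.append(last_chunk)

    assert unchunk(chunks)[-2:] == [8, 9]         # runt items stay at the end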