cleanup

This commit is contained in:
  parent b3c5d656e3
  commit 1874a38663
@@ -15,14 +15,11 @@ limitations under the License.
 """
 import bisect
 import logging
-import os.path
-from collections import defaultdict
 import math
 import copy
 
 import random
-from itertools import groupby
-from typing import Tuple, List, Dict
+from typing import List, Dict
 
 from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
 import PIL.Image
@@ -31,37 +28,6 @@ from utils.first_fit_decreasing import first_fit_decreasing
 
 PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
 
-def __chunk(l: List, chunk_size) -> List:
-    num_chunks = int(math.ceil(float(len(l)) / chunk_size))
-    return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]
-
-def __unchunk(chunked_list: List):
-    return [i for c in chunked_list for i in c]
-
-
-
-
-def __collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
-    # interleave buckets while trying to maximise shared grad accum chunks
-    batch_ids = [k[0] for k in buckets.keys()]
-    items_by_batch_id = {}
-    for batch_id in batch_ids:
-        items_by_batch_id[batch_id] = __unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
-    return items_by_batch_id
-
-def __flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]],
-                                                       batch_size: int,
-                                                       grad_accum: int) -> List[ImageTrainItem]:
-    # ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
-    filler_items = __chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
-    custom_batched_items = [__chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
-    neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,
-                                                     batch_size=grad_accum,
-                                                     filler_items=filler_items)
-
-    items: List[ImageTrainItem] = __unchunk(neighbourly_chunked_items)
-    return items
-
 class DataLoaderMultiAspect():
     """
     Data loader for multi-aspect-ratio training and bucketing
@@ -144,8 +110,9 @@ class DataLoaderMultiAspect():
 
         for image_caption_pair in picked_images:
             image_caption_pair.runt_size = 0
-            batch_id = image_caption_pair.batch_id
-            bucket_key = (batch_id, image_caption_pair.target_wh[0], image_caption_pair.target_wh[1])
+            bucket_key = (image_caption_pair.batch_id,
+                          image_caption_pair.target_wh[0],
+                          image_caption_pair.target_wh[1])
             if bucket_key not in buckets:
                 buckets[bucket_key] = []
             buckets[bucket_key].append(image_caption_pair)
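Note: the bucket key is the tuple (batch_id, target width, target height), so two images share a bucket only when they carry the same batch label and the same target resolution. A minimal standalone sketch of that grouping, using hypothetical (batch_id, (width, height)) pairs in place of real ImageTrainItem objects:

    from collections import defaultdict

    # hypothetical stand-ins for ImageTrainItem: (batch_id, (width, height))
    picked_images = [
        ("default_batch", (512, 768)),
        ("my_set", (768, 512)),
        ("default_batch", (512, 768)),
    ]

    buckets = defaultdict(list)
    for item in picked_images:
        batch_id, (w, h) = item
        buckets[(batch_id, w, h)].append(item)

    # the two default_batch items share one bucket; my_set sits alone
    print({k: len(v) for k, v in buckets.items()})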
@@ -167,14 +134,14 @@ class DataLoaderMultiAspect():
 
         # handle batch_id
         # unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
-        items_by_batch_id = __collapse_buckets_by_batch_id(buckets)
-        items = __flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
+        items_by_batch_id = collapse_buckets_by_batch_id(buckets)
+        items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
                                                                  batch_size=batch_size,
                                                                  grad_accum=grad_accum)
 
         # chunk by effective batch size
         effective_batch_size = batch_size * grad_accum
-        chunks = __chunk(items, effective_batch_size)
+        chunks = chunk(items, effective_batch_size)
         # shuffle, but preserve the last chunk as last if it is incomplete
         last_chunk = None
         if len(chunks[-1]) < effective_batch_size:
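For context, chunk and unchunk (moved to module level at the bottom of this diff) split a list into fixed-size pieces and flatten it back. The loader shuffles at the granularity of effective_batch_size = batch_size * grad_accum so gradient-accumulation windows stay intact, keeping an incomplete trailing chunk last. A runnable sketch with illustrative values:

    import math
    import random
    from typing import List

    def chunk(l: List, chunk_size) -> List:
        num_chunks = int(math.ceil(float(len(l)) / chunk_size))
        return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]

    def unchunk(chunked_list: List):
        return [i for c in chunked_list for i in c]

    items = list(range(10))
    effective_batch_size = 2 * 2  # batch_size=2, grad_accum=2
    chunks = chunk(items, effective_batch_size)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

    # shuffle, but keep an incomplete last chunk at the end, as in the loader
    last_chunk = chunks.pop() if len(chunks[-1]) < effective_batch_size else None
    random.shuffle(chunks)
    if last_chunk is not None:
        chunks.append(last_chunk)
    items = unchunk(chunks)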
@@ -183,7 +150,7 @@ class DataLoaderMultiAspect():
         if last_chunk is not None:
             chunks.append(last_chunk)
         # un-chunk
-        items = __unchunk(chunks)
+        items = unchunk(chunks)
 
         return items
 
@@ -232,3 +199,31 @@ class DataLoaderMultiAspect():
             self.rating_overall_sum += item.caption.rating()
             self.ratings_summed.append(self.rating_overall_sum)
 
+
+def chunk(l: List, chunk_size) -> List:
+    num_chunks = int(math.ceil(float(len(l)) / chunk_size))
+    return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]
+
+def unchunk(chunked_list: List):
+    return [i for c in chunked_list for i in c]
+
+def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
+    # interleave buckets while trying to maximise shared grad accum chunks
+    batch_ids = [k[0] for k in buckets.keys()]
+    items_by_batch_id = {}
+    for batch_id in batch_ids:
+        items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
+    return items_by_batch_id
+
+def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]],
+                                                     batch_size: int,
+                                                     grad_accum: int) -> List[ImageTrainItem]:
+    # ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
+    filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
+    custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
+    neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,
+                                                     batch_size=grad_accum,
+                                                     filler_items=filler_items)
+
+    items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
+    return items
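Usage sketch for the renamed helpers, on a toy buckets dict (keys are the (batch_id, width, height) tuples built above; strings stand in for ImageTrainItem objects):

    # toy buckets dict; strings stand in for ImageTrainItem objects
    buckets = {
        ("default_batch", 512, 768): ["a", "b"],
        ("my_set", 512, 768): ["c"],
        ("my_set", 768, 512): ["d", "e"],
    }

    items_by_batch_id = collapse_buckets_by_batch_id(buckets)
    # -> {"default_batch": ["a", "b"], "my_set": ["c", "d", "e"]}

flatten_buckets_preserving_named_batch_adjacency then chunks each named batch into batch_size-sized units before handing them to first_fit_decreasing, so a single batch never mixes aspect ratios; DEFAULT_BATCH_ID chunks serve as filler.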