"""
Copyright [ 2022 ] Victor C Hall
Licensed under the GNU Affero General Public License ;
You may not use this code except in compliance with the License .
You may obtain a copy of the License at
https : / / www . gnu . org / licenses / agpl - 3.0 . en . html
Unless required by applicable law or agreed to in writing , software
distributed under the License is distributed on an " AS IS " BASIS ,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied .
See the License for the specific language governing permissions and
limitations under the License .
"""
import bisect
import logging
import os.path
from collections import defaultdict
import math
import copy
import random
from itertools import groupby
from typing import Tuple, List

from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
import PIL.Image
from utils.first_fit_decreasing import first_fit_decreasing

PIL.Image.MAX_IMAGE_PIXELS = 715827880 * 4  # increase decompression bomb error limit to 4x default

class DataLoaderMultiAspect():
    """
    Data loader for multi-aspect-ratio training and bucketing

    image_train_items: list of `ImageTrainItem` objects
    seed: random seed
    batch_size: number of images per batch
    """
    def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1, grad_accum=1):
        self.seed = seed
        self.batch_size = batch_size
        self.grad_accum = grad_accum
        self.prepared_train_data = image_train_items
        random.Random(self.seed).shuffle(self.prepared_train_data)
        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())

        self.expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
        if self.expected_epoch_size != len(self.prepared_train_data):
            logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.expected_epoch_size} images.")
        else:
            logging.info(f" * DLMA initialized with {len(image_train_items)} images.")

        self.rating_overall_sum: float = 0.0
        self.ratings_summed: list[float] = []
        self.__update_rating_sums()

    def __pick_multiplied_set(self, randomizer: random.Random):
        """
        Expands the data set according to each image's multiplier (from per-folder multiply.txt):
        whole-number parts are always included, fractional parts are resolved probabilistically.
        """
        picked_images = []
        data_copy = copy.deepcopy(self.prepared_train_data)  # deep copy to avoid modifying original multiplier property
        for iti in data_copy:
            while iti.multiplier >= 1:
                picked_images.append(iti)
                iti.multiplier -= 1

        remaining = self.expected_epoch_size - len(picked_images)

        assert remaining >= 0, "Something went wrong with the multiplier calculation"

        # resolve fractional parts, ensure each is only added max once
        while remaining > 0:
            for iti in data_copy:
                if randomizer.random() < iti.multiplier:
                    picked_images.append(iti)
                    iti.multiplier = 0
                    remaining -= 1
                    if remaining <= 0:
                        break

        return picked_images

    def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
        """
        Returns the current list of `ImageTrainItem` in randomized order,
        sorted into buckets with same sized images.

        If dropout_fraction < 1.0, only a subset of the images will be returned.

        If dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan.

        :param dropout_fraction: must be between 0.0 and 1.0.
        :return: Randomized list of `ImageTrainItem` objects
        """
        self.seed += 1
        randomizer = random.Random(self.seed)

        if dropout_fraction < 1.0:
            picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
        else:
            picked_images = self.__pick_multiplied_set(randomizer)

        randomizer.shuffle(picked_images)

        buckets = {}
        batch_size = self.batch_size
        grad_accum = self.grad_accum

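        # Group items by (batch_id, target width, target height) so that every batch
        # is built from images sharing a single bucket resolution.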
        for image_caption_pair in picked_images:
            image_caption_pair.runt_size = 0
            batch_id = image_caption_pair.batch_id
            bucket_key = (batch_id, image_caption_pair.target_wh[0], image_caption_pair.target_wh[1])
            if bucket_key not in buckets:
                buckets[bucket_key] = []
            buckets[bucket_key].append(image_caption_pair)

        # handle "runts": if a bucket is not evenly divisible by batch_size, randomly
        # duplicate items from its leftover tail so that bucket's final batch is full
        for bucket in buckets:
            truncate_count = len(buckets[bucket]) % batch_size
            if truncate_count > 0:
                runt_bucket = buckets[bucket][-truncate_count:]
                for item in runt_bucket:
                    item.runt_size = truncate_count
                while len(runt_bucket) < batch_size:
                    runt_bucket.append(random.choice(runt_bucket))

                current_bucket_size = len(buckets[bucket])

                buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
                buckets[bucket].extend(runt_bucket)

        def chunk(l: List, chunk_size) -> List:
            num_chunks = int(math.ceil(float(len(l)) / chunk_size))
            return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]

        def unchunk(chunked_list: List):
            return [i for c in chunked_list for i in c]

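        # For reference, the helpers above behave as follows:
        #   chunk([1, 2, 3, 4, 5], 2)      -> [[1, 2], [3, 4], [5]]
        #   unchunk([[1, 2], [3, 4], [5]]) -> [1, 2, 3, 4, 5]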
        # interleave buckets while trying to maximise shared grad accum chunks
        batch_ids = [k[0] for k in buckets.keys()]
        items_by_batch_id = {}
        for batch_id in batch_ids:
            items_by_batch_id[batch_id] = unchunk([b for bucket_key, b in buckets.items() if bucket_key[0] == batch_id])

        # ensure we don't mix and match aspect ratios by treating each chunk of batch_size images as a single unit to pass to first_fit_decreasing
        filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
        custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
        #custom_batched_items = chunk([b for bucket_key,b in buckets.items() if bucket_key[0] != DEFAULT_BATCH_ID], batch_size)
        neighbourly_chunked_items = first_fit_decreasing(custom_batched_items, batch_size=grad_accum, filler_items=filler_items)

        items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)

        # chunk by effective batch size
        effective_batch_size = batch_size * grad_accum
        chunks = chunk(items, effective_batch_size)

        # shuffle, but preserve the last chunk as last if it is incomplete
        last_chunk = None
        if len(chunks[-1]) < effective_batch_size:
            last_chunk = chunks.pop(-1)
        random.shuffle(chunks)
        if last_chunk is not None:
            chunks.append(last_chunk)

        # un-chunk
        items = unchunk(chunks)

        return items

    def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
        """
        Picks a random subset of all images
        - The size of the subset is limited by dropout_fraction
        - The chance of an image being picked is proportional to its rating: double the rating, double the chance
        :param dropout_fraction: must be between 0.0 and 1.0
        :param picker: seeded random picker
        :return: list of picked ImageTrainItem
        """
        prepared_train_data = self.prepared_train_data.copy()
        ratings_summed = self.ratings_summed.copy()
        rating_overall_sum = self.rating_overall_sum

        num_images = len(prepared_train_data)
        num_images_to_pick = math.ceil(num_images * dropout_fraction)
        num_images_to_pick = max(min(num_images_to_pick, num_images), 0)

        # logging.info(f"Picking {num_images_to_pick} images out of the {num_images} in the dataset for drop_fraction {dropout_fraction}")

        picked_images: list[ImageTrainItem] = []
        while num_images_to_pick > len(picked_images):
            # find random sample in dataset
            point = picker.uniform(0.0, rating_overall_sum)
            pos = min(bisect.bisect_left(ratings_summed, point), len(prepared_train_data) - 1)

            # pick random sample
            picked_image = prepared_train_data[pos]
            picked_images.append(picked_image)

            # kick picked item out of data set to not pick it again
            rating_overall_sum = max(rating_overall_sum - picked_image.caption.rating(), 0.0)
            ratings_summed.pop(pos)
            prepared_train_data.pop(pos)

        return picked_images

    def __update_rating_sums(self):
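        # Rebuild the cumulative rating sums that __pick_random_subset uses for
        # rating-weighted sampling.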
        self.rating_overall_sum: float = 0.0
        self.ratings_summed: list[float] = []
        for item in self.prepared_train_data:
            self.rating_overall_sum += item.caption.rating()
            self.ratings_summed.append(self.rating_overall_sum)