2022-12-17 20:32:48 -07:00
"""
Copyright [ 2022 ] Victor C Hall
Licensed under the GNU Affero General Public License ;
You may not use this code except in compliance with the License .
You may obtain a copy of the License at
https : / / www . gnu . org / licenses / agpl - 3.0 . en . html
Unless required by applicable law or agreed to in writing , software
distributed under the License is distributed on an " AS IS " BASIS ,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied .
See the License for the specific language governing permissions and
limitations under the License .
"""
2023-01-14 06:00:30 -07:00
import bisect
2023-02-08 03:28:45 -07:00
import logging
2023-02-08 06:15:54 -07:00
import os . path
from collections import defaultdict
2023-01-14 06:00:30 -07:00
import math
2023-02-08 11:04:12 -07:00
import copy
2023-01-07 11:57:23 -07:00
2022-12-17 20:32:48 -07:00
import random
2023-04-19 03:06:02 -06:00
from colorama import Fore , Style
2023-02-08 06:15:54 -07:00
from data . image_train_item import ImageTrainItem
import PIL . Image
2023-01-01 08:45:18 -07:00
# Raise Pillow's decompression-bomb pixel limit so very large training images can be opened.
# The limit is 4 * 715827880 = 2,863,311,520 pixels. Pillow's default MAX_IMAGE_PIXELS is
# 89478485, so this is ~32x the default (715827880 itself is already 8x the default).
PIL.Image.MAX_IMAGE_PIXELS = 715827880 * 4
2022-12-17 20:32:48 -07:00
class DataLoaderMultiAspect():
    """
    Data loader for multi-aspect-ratio training and bucketing

    image_train_items: list of `ImageTrainItem` objects
    seed: random seed
    batch_size: number of images per batch
    """
    def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
        self.seed = seed
        self.batch_size = batch_size

        # Work on a private copy: shuffling in place would silently reorder the caller's list.
        self.prepared_train_data = list(image_train_items)
        # Shuffle before the (stable) sort so that items with equal ratings end up in a
        # seeded-random order rather than in input order.
        random.Random(self.seed).shuffle(self.prepared_train_data)
        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())

        # Whole epoch size once every multiplier (whole + fractional part) has been applied.
        self.expected_epoch_size = math.floor(sum(i.multiplier for i in self.prepared_train_data))
        if self.expected_epoch_size != len(self.prepared_train_data):
            logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.expected_epoch_size} images.")
        else:
            logging.info(f" * DLMA initialized with {len(image_train_items)} images.")

        # Cumulative rating table used for rating-weighted sampling in __pick_random_subset.
        self.rating_overall_sum: float = 0.0
        self.ratings_summed: list[float] = []
        self.__update_rating_sums()

    def __pick_multiplied_set(self, randomizer: random.Random):
        """
        Deals with multiply.txt whole and fractional numbers.

        Every item is added once per whole unit of its `multiplier`. The leftover
        fractional parts are then resolved randomly until the result holds
        `self.expected_epoch_size` entries; each item can gain at most one extra
        entry from its fractional part.

        :param randomizer: seeded random source for this epoch
        :return: list of deep-copied ImageTrainItem (an item with multiplier >= 2
                 appears several times as the same copied object)
        """
        picked_images = []
        # deep copy to avoid modifying original multiplier property
        data_copy = copy.deepcopy(self.prepared_train_data)
        for iti in data_copy:
            while iti.multiplier >= 1:
                picked_images.append(iti)
                iti.multiplier -= 1

        remaining = self.expected_epoch_size - len(picked_images)
        assert remaining >= 0, "Something went wrong with the multiplier calculation"

        # Resolve fractional parts, ensuring each item is added at most once.
        # Candidates are visited in a seeded-random order: walking the rating-sorted list
        # directly would systematically favour items near its start when few picks remain.
        open_items = [iti for iti in data_copy if iti.multiplier > 0]
        randomizer.shuffle(open_items)
        # Stop when no candidate is left, so float rounding can never cause an endless loop.
        while remaining > 0 and open_items:
            not_picked = []
            for iti in open_items:
                if remaining > 0 and randomizer.random() < iti.multiplier:
                    picked_images.append(iti)
                    iti.multiplier = 0
                    remaining -= 1
                else:
                    not_picked.append(iti)
            open_items = not_picked

        return picked_images

    def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0) -> list[ImageTrainItem]:
        """
        Returns the current list of `ImageTrainItem` in randomized order,
        sorted into buckets with same sized images.
        If dropout_fraction < 1.0, only a subset of the images will be returned.
        If dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan.

        Buckets are keyed by `target_wh`. When a bucket's size is not a multiple of
        `batch_size`, its trailing "runt" items are padded with random duplicates of
        those same runt items up to a full batch. Runt items get `runt_size` set to the
        runt's original length; every other returned item gets `runt_size = 0`.

        All randomness comes from a `random.Random` seeded with `self.seed`, which is
        incremented on every call. The result is therefore reproducible for a given
        starting seed, independent of the global `random` state.

        :param dropout_fraction: must be between 0.0 and 1.0.
        :return: Randomized list of `ImageTrainItem` objects
        """
        self.seed += 1
        randomizer = random.Random(self.seed)
        if dropout_fraction < 1.0:
            picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
        else:
            picked_images = self.__pick_multiplied_set(randomizer)

        randomizer.shuffle(picked_images)

        batch_size = self.batch_size
        # Insertion-ordered: buckets appear in the order their first item appears after the shuffle.
        buckets = defaultdict(list)
        for image_caption_pair in picked_images:
            image_caption_pair.runt_size = 0
            target_wh = image_caption_pair.target_wh
            buckets[(target_wh[0], target_wh[1])].append(image_caption_pair)

        for bucket_items in buckets.values():
            truncate_count = len(bucket_items) % batch_size
            if truncate_count == 0:
                continue
            runt_bucket = bucket_items[-truncate_count:]
            for item in runt_bucket:
                item.runt_size = truncate_count
            # Pad the runt up to a full batch. Draw from the seeded randomizer (not the global
            # `random` module) so epochs stay reproducible. Sample only from the original runt
            # items so that earlier duplicates are not more likely to be duplicated again.
            bucket_items.extend(randomizer.choice(runt_bucket) for _ in range(batch_size - truncate_count))

        # flatten the buckets
        items: list[ImageTrainItem] = []
        for bucket_items in buckets.values():
            items.extend(bucket_items)

        return items

    def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
        """
        Picks a random subset of all images, without replacement.
        - The size of the subset is limited by dropout_fraction
        - The chance of an image to be picked is influenced by its rating. Double that rating -> double the chance

        Sampling uses the cumulative table `self.ratings_summed`. After every pick, the
        picked item's rating is removed from all later cumulative entries, so the remaining
        items keep weights proportional to their own ratings.
        (Assumes `caption.rating()` still returns the value used when the table was built.)

        :param dropout_fraction: must be between 0.0 and 1.0
        :param picker: seeded random picker
        :return: list of picked ImageTrainItem (the original objects, not copies)
        """
        prepared_train_data = self.prepared_train_data.copy()
        ratings_summed = self.ratings_summed.copy()
        rating_overall_sum = self.rating_overall_sum

        num_images = len(prepared_train_data)
        num_images_to_pick = math.ceil(num_images * dropout_fraction)
        num_images_to_pick = max(min(num_images_to_pick, num_images), 0)

        picked_images: list[ImageTrainItem] = []
        while num_images_to_pick > len(picked_images):
            # find random sample in dataset
            point = picker.uniform(0.0, rating_overall_sum)
            # clamp: float drift can put `point` past the last cumulative entry
            pos = min(bisect.bisect_left(ratings_summed, point), len(prepared_train_data) - 1)

            # pick random sample and kick it out of the data set so it is not picked again
            picked_image = prepared_train_data.pop(pos)
            picked_images.append(picked_image)

            picked_rating = picked_image.caption.rating()
            rating_overall_sum = max(rating_overall_sum - picked_rating, 0.0)
            ratings_summed.pop(pos)
            # Every later cumulative sum still includes the removed rating; without this
            # correction the next item would inherit the removed item's weight.
            for i in range(pos, len(ratings_summed)):
                ratings_summed[i] -= picked_rating

        return picked_images

    def __update_rating_sums(self):
        """
        Rebuilds the cumulative rating table from `self.prepared_train_data`.

        After the call, `self.ratings_summed[i]` is the sum of the ratings of items 0..i,
        and `self.rating_overall_sum` is the sum over all items.
        """
        self.rating_overall_sum: float = 0.0
        self.ratings_summed: list[float] = []
        for item in self.prepared_train_data:
            self.rating_overall_sum += item.caption.rating()
            self.ratings_summed.append(self.rating_overall_sum)