2022-12-17 20:32:48 -07:00
"""
Copyright [2022] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
2023-01-14 06:00:30 -07:00
import bisect
import math
2022-12-17 20:32:48 -07:00
import os
2022-12-27 12:25:32 -07:00
import logging
2023-01-22 16:59:59 -07:00
import copy
2023-01-07 11:57:23 -07:00
2022-12-17 20:32:48 -07:00
import random
2023-01-23 01:15:32 -07:00
from data . image_train_item import ImageTrainItem
2022-12-17 20:32:48 -07:00
import data . aspects as aspects
2023-01-23 00:13:05 -07:00
import data . resolver as resolver
2022-12-27 12:25:32 -07:00
from colorama import Fore , Style
2023-01-01 08:45:18 -07:00
# `import PIL` alone does not load the PIL.Image submodule; import it explicitly so the
# attribute access below does not depend on some other module having imported it first.
# (`import PIL.Image` still binds the name `PIL`.)
import PIL.Image
# Raise Pillow's decompression-bomb pixel limit well above its default so very large
# training images can still be opened.
PIL.Image.MAX_IMAGE_PIXELS = 715827880 * 4
2022-12-17 20:32:48 -07:00
class DataLoaderMultiAspect():
    """
    Data loader for multi-aspect-ratio training and bucketing.

    data_root: root folder of training data
    seed: base seed; incremented by one on every call to get_shuffled_image_buckets
    debug_level: stored as self.debug_level
    batch_size: number of images per batch
    flip_p: probability of flipping image horizontally (i.e. 0-0.5), forwarded to the resolver
    resolution: base resolution used to build the aspect-ratio buckets
    log_folder: folder that receives undersized_images.txt; when None that report goes to the log
    """
    def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
        self.data_root = data_root
        self.debug_level = debug_level
        self.flip_p = flip_p
        self.log_folder = log_folder
        self.seed = seed
        self.batch_size = batch_size
        self.has_scanned = False
        self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)

        logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")

        # forward the configured flip probability; relying on the method default would always pass 0.0
        self.__prepare_train_data(flip_p=self.flip_p)
        (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()

    def __pick_multiplied_set(self, randomizer):
        """
        Deals with multiply.txt whole and fractional numbers.

        Every item is added floor(multiplier) times. The remaining slots of the epoch are filled
        by items with a leftover fractional multiplier, each picked with probability equal to that
        fraction and at most once. The epoch size is floor(sum of all multipliers), so whole
        multipliers > 1 lengthen the epoch and fractional ones < 1 shorten it.

        :param randomizer: seeded random.Random used for the fractional picks
        :return: list of deep-copied ImageTrainItem (an item repeated by its multiplier appears
                 several times as the same copy), not shuffled
        """
        data_copy = copy.deepcopy(self.prepared_train_data)  # deep copy to avoid modifying original multiplier property
        # fsum keeps e.g. ten 0.1 multipliers summing to exactly 1.0 before flooring
        epoch_size = math.floor(math.fsum(iti.multiplier for iti in data_copy))
        picked_images = []

        # add by whole number part first and decrement multiplier in copy
        for iti in data_copy:
            while iti.multiplier >= 1.0:
                picked_images.append(iti)
                iti.multiplier -= 1.0

        remaining = epoch_size - len(picked_images)
        assert remaining >= 0, "Something went wrong with the multiplier calculation"

        # fill the remaining slots from fractional parts by random chance, each item at most once
        while remaining > 0:
            if not any(iti.multiplier > 0.0 for iti in data_copy):
                break  # no candidates left; without this the loop could never terminate
            for iti in data_copy:
                if randomizer.uniform(0.0, 1.0) < iti.multiplier:
                    picked_images.append(iti)
                    remaining -= 1
                    iti.multiplier = 0.0
                if remaining <= 0:
                    break

        return picked_images

    def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
        """
        returns the current list of images including their captions in a randomized order,
        sorted into buckets with same sized images

        if dropout_fraction < 1.0, only a subset of the images will be returned
        if dropout_fraction >= 1.0, repicks fractional multipliers based on folder/multiply.txt values swept at prescan

        When there is more than one bucket, a bucket whose size is not a multiple of batch_size has
        its trailing partial batch (the "runt") padded to a full batch with random repeats of runt
        items; those items get runt_size set to the runt length (all other picked items get 0).
        All randomness comes from a Random seeded with self.seed (which is incremented first).

        :param dropout_fraction: must be between 0.0 and 1.0.
        :return: randomized list of ImageTrainItem, grouped by target_wh bucket
        """
        self.seed += 1
        randomizer = random.Random(self.seed)
        if dropout_fraction < 1.0:
            picked_images = self.__pick_random_subset(dropout_fraction, randomizer)
        else:
            picked_images = self.__pick_multiplied_set(randomizer)

        randomizer.shuffle(picked_images)

        buckets = {}
        batch_size = self.batch_size
        for image_caption_pair in picked_images:
            image_caption_pair.runt_size = 0
            target_wh = image_caption_pair.target_wh
            buckets.setdefault((target_wh[0], target_wh[1]), []).append(image_caption_pair)

        if len(buckets) > 1:
            for key in list(buckets):
                bucket = buckets[key]
                truncate_count = len(bucket) % batch_size
                if truncate_count == 0:
                    continue
                runt_bucket = bucket[-truncate_count:]
                for item in runt_bucket:
                    item.runt_size = truncate_count
                while len(runt_bucket) < batch_size:
                    # seeded randomizer (not the global `random`) keeps epochs reproducible from seed
                    runt_bucket.append(randomizer.choice(runt_bucket))
                buckets[key] = bucket[:len(bucket) - truncate_count] + runt_bucket

        # flatten the buckets
        return [item for bucket in buckets.values() for item in bucket]

    def __sort_and_precalc_image_ratings(self) -> tuple[float, list[float]]:
        """
        Sorts self.prepared_train_data by caption rating (ascending) and precomputes running sums.

        :return: (sum of all ratings, cumulative rating sum at each position of the sorted list)
        """
        self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
        rating_overall_sum: float = 0.0
        ratings_summed: list[float] = []
        for image in self.prepared_train_data:
            rating_overall_sum += image.caption.rating()
            ratings_summed.append(rating_overall_sum)
        return rating_overall_sum, ratings_summed

    def __prepare_train_data(self, flip_p=0.0) -> None:
        """
        Create ImageTrainItem objects with metadata for hydration later.

        Runs at most once per instance (guarded by has_scanned). Items whose `error` is not None
        are excluded from self.prepared_train_data, which is shuffled with self.seed; all
        errors/undersized images are then reported.

        :param flip_p: horizontal flip probability forwarded to resolver.resolve
        """
        if self.has_scanned:
            return
        self.has_scanned = True

        logging.info("Preloading images...")
        items = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
        image_paths = {item.pathname for item in items}
        print(f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")

        self.prepared_train_data = [item for item in items if item.error is None]
        random.Random(self.seed).shuffle(self.prepared_train_data)
        self.__report_errors(items)

    def __report_errors(self, items: "list[ImageTrainItem]") -> None:
        """
        Logs items that failed to load and reports images smaller than their target size.

        Undersized images are listed in <log_folder>/undersized_images.txt; when log_folder is
        None they are logged as warnings instead.
        """
        for item in items:
            if item.error is not None:
                logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{item.pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
                logging.error(f" *** exception: {item.error}")

        undersized_items = [item for item in items if item.is_undersized]
        if not undersized_items:
            return

        messages = [
            f" *** {item.pathname} with size: {item.image_size} is smaller than target size: {item.target_wh}\n"
            for item in undersized_items
        ]
        logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")

        if self.log_folder is None:
            # no folder to write the report into (os.path.join(None, ...) would raise)
            for message in messages:
                logging.warning(message.rstrip("\n"))
            return

        undersized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
        logging.warning(f"{Fore.LIGHTRED_EX} ** Check {undersized_log_path} for more information.{Style.RESET_ALL}")
        # explicit utf-8 so non-ASCII paths can be written regardless of the platform default encoding
        with open(undersized_log_path, "w", encoding="utf-8") as undersized_images_file:
            # trailing newline keeps the first entry on its own line
            undersized_images_file.write("The following images are smaller than the target size, consider removing or sourcing a larger copy:\n")
            undersized_images_file.writelines(messages)

    def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> "list[ImageTrainItem]":
        """
        Picks a random subset of all images
        - The size of the subset is limited by dropout_fraction (rounded up, clamped to the dataset size)
        - The chance of an image to be picked is influenced by its rating. Double that rating -> double the chance
        - Each image is picked at most once
        :param dropout_fraction: must be between 0.0 and 1.0
        :param picker: seeded random picker
        :return: list of picked ImageTrainItem
        """
        # work on copies so the precomputed tables on self stay intact for later epochs
        prepared_train_data = self.prepared_train_data.copy()
        ratings_summed = self.ratings_summed.copy()
        rating_overall_sum = self.rating_overall_sum

        num_images = len(prepared_train_data)
        num_images_to_pick = math.ceil(num_images * dropout_fraction)
        num_images_to_pick = max(min(num_images_to_pick, num_images), 0)

        picked_images = []
        while num_images_to_pick > len(picked_images):
            # find random sample in dataset, weighted by rating via the cumulative sums
            point = picker.uniform(0.0, rating_overall_sum)
            pos = min(bisect.bisect_left(ratings_summed, point), len(prepared_train_data) - 1)

            picked_image = prepared_train_data.pop(pos)
            picked_images.append(picked_image)

            # kick picked item out of the cumulative table: every later running sum must shrink by
            # the removed rating, otherwise the table disagrees with rating_overall_sum and the
            # tail of the list becomes less likely (or impossible) to pick
            picked_rating = picked_image.caption.rating()
            rating_overall_sum = max(rating_overall_sum - picked_rating, 0.0)
            ratings_summed.pop(pos)
            for i in range(pos, len(ratings_summed)):
                ratings_summed[i] -= picked_rating

        return picked_images