2022-12-17 20:32:48 -07:00
"""
Copyright [ 2022 ] Victor C Hall
Licensed under the GNU Affero General Public License ;
You may not use this code except in compliance with the License .
You may obtain a copy of the License at
https : / / www . gnu . org / licenses / agpl - 3.0 . en . html
Unless required by applicable law or agreed to in writing , software
distributed under the License is distributed on an " AS IS " BASIS ,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied .
See the License for the specific language governing permissions and
limitations under the License .
"""
2023-01-14 06:00:30 -07:00
import bisect
import math
2022-12-17 20:32:48 -07:00
import os
2022-12-27 12:25:32 -07:00
import logging
2023-01-07 11:57:23 -07:00
import yaml
2022-12-17 20:32:48 -07:00
from PIL import Image
import random
2023-01-07 09:29:09 -07:00
from data . image_train_item import ImageTrainItem , ImageCaption
2022-12-17 20:32:48 -07:00
import data . aspects as aspects
2022-12-27 12:25:32 -07:00
from colorama import Fore , Style
2022-12-29 19:11:06 -07:00
import zipfile
2023-01-01 08:45:18 -07:00
import tqdm
import PIL
# Raise Pillow's decompression-bomb threshold to 4x its default so that
# legitimately large training images do not raise DecompressionBombError.
PIL.Image.MAX_IMAGE_PIXELS = 715827880 * 4

# Fallback cap on caption token length when a caption sidecar does not
# specify its own "max_caption_length".
DEFAULT_MAX_CAPTION_LENGTH = 2048
2022-12-17 20:32:48 -07:00
class DataLoaderMultiAspect ( ) :
"""
Data loader for multi - aspect - ratio training and bucketing
data_root : root folder of training data
batch_size : number of images per batch
flip_p : probability of flipping image horizontally ( i . e . 0 - 0.5 )
"""
2022-12-27 12:25:32 -07:00
def __init__ ( self , data_root , seed = 555 , debug_level = 0 , batch_size = 1 , flip_p = 0.0 , resolution = 512 , log_folder = None ) :
2022-12-17 20:32:48 -07:00
self . image_paths = [ ]
self . debug_level = debug_level
self . flip_p = flip_p
2022-12-27 12:25:32 -07:00
self . log_folder = log_folder
2023-01-01 08:45:18 -07:00
self . seed = seed
self . batch_size = batch_size
2023-01-20 07:42:24 -07:00
self . has_scanned = False
2022-12-17 20:32:48 -07:00
self . aspects = aspects . get_aspect_buckets ( resolution = resolution , square_only = False )
2022-12-27 12:25:32 -07:00
logging . info ( f " * DLMA resolution { resolution } , buckets: { self . aspects } " )
logging . info ( " Preloading images... " )
2022-12-17 20:32:48 -07:00
2022-12-29 19:11:06 -07:00
self . unzip_all ( data_root )
2022-12-17 20:32:48 -07:00
self . __recurse_data_root ( self = self , recurse_root = data_root )
random . Random ( seed ) . shuffle ( self . image_paths )
2023-01-14 06:00:30 -07:00
self . prepared_train_data = self . __prescan_images ( self . image_paths , flip_p )
( self . rating_overall_sum , self . ratings_summed ) = self . __sort_and_precalc_image_ratings ( )
2023-01-21 23:15:50 -07:00
def __pick_multiplied_set ( self , randomizer ) :
"""
Deals with multiply . txt whole and fractional numbers
"""
prepared_train_data_local = self . prepared_train_data . copy ( )
epoch_size = len ( self . prepared_train_data )
picked_images = [ ]
# add by whole number part first and decrement multiplier in copy
for iti in prepared_train_data_local :
while iti . multiplier > = 1.0 :
picked_images . append ( iti )
iti . multiplier - = 1
if iti . multiplier == 0 :
prepared_train_data_local . remove ( iti )
remaining = epoch_size - len ( picked_images )
assert remaining > = 0 , " Something went wrong with the multiplier calculation "
# add by renaming fractional numbers by random chance
while remaining > 0 :
for iti in prepared_train_data_local :
if randomizer . uniform ( 0.0 , 1 ) < iti . multiplier :
picked_images . append ( iti )
remaining - = 1
prepared_train_data_local . remove ( iti )
if remaining < = 0 :
break
return picked_images
2023-01-14 06:00:30 -07:00
def get_shuffled_image_buckets ( self , dropout_fraction : float = 1.0 ) :
"""
returns the current list of images including their captions in a randomized order ,
sorted into buckets with same sized images
if dropout_fraction < 1.0 , only a subset of the images will be returned
2023-01-21 23:15:50 -07:00
if dropout_fraction > = 1.0 , repicks fractional multipliers based on folder / multiply . txt values swept at prescan
2023-01-14 06:00:30 -07:00
: param dropout_fraction : must be between 0.0 and 1.0 .
: return : randomized list of ( image , caption ) pairs , sorted into same sized buckets
"""
self . seed + = 1
randomizer = random . Random ( self . seed )
if dropout_fraction < 1.0 :
picked_images = self . __pick_random_subset ( dropout_fraction , randomizer )
else :
2023-01-21 23:15:50 -07:00
picked_images = self . __pick_multiplied_set ( randomizer )
2023-01-14 06:00:30 -07:00
randomizer . shuffle ( picked_images )
buckets = { }
batch_size = self . batch_size
for image_caption_pair in picked_images :
image_caption_pair . runt_size = 0
target_wh = image_caption_pair . target_wh
if ( target_wh [ 0 ] , target_wh [ 1 ] ) not in buckets :
buckets [ ( target_wh [ 0 ] , target_wh [ 1 ] ) ] = [ ]
buckets [ ( target_wh [ 0 ] , target_wh [ 1 ] ) ] . append ( image_caption_pair )
if len ( buckets ) > 1 :
for bucket in buckets :
truncate_count = len ( buckets [ bucket ] ) % batch_size
if truncate_count > 0 :
runt_bucket = buckets [ bucket ] [ - truncate_count : ]
for item in runt_bucket :
item . runt_size = truncate_count
while len ( runt_bucket ) < batch_size :
runt_bucket . append ( random . choice ( runt_bucket ) )
current_bucket_size = len ( buckets [ bucket ] )
buckets [ bucket ] = buckets [ bucket ] [ : current_bucket_size - truncate_count ]
buckets [ bucket ] . extend ( runt_bucket )
2023-01-07 11:57:23 -07:00
2023-01-14 06:00:30 -07:00
# flatten the buckets
image_caption_pairs = [ ]
for bucket in buckets :
image_caption_pairs . extend ( buckets [ bucket ] )
2022-12-17 20:32:48 -07:00
2023-01-14 06:00:30 -07:00
return image_caption_pairs
@staticmethod
def unzip_all ( path ) :
2022-12-29 19:11:06 -07:00
try :
for root , dirs , files in os . walk ( path ) :
for file in files :
if file . endswith ( ' .zip ' ) :
logging . info ( f " Unzipping { file } " )
with zipfile . ZipFile ( path , ' r ' ) as zip_ref :
zip_ref . extractall ( path )
except Exception as e :
logging . error ( f " Error unzipping files { e } " )
2022-12-17 20:32:48 -07:00
2023-01-14 06:00:30 -07:00
def __sort_and_precalc_image_ratings ( self ) - > tuple [ float , list [ float ] ] :
self . prepared_train_data = sorted ( self . prepared_train_data , key = lambda img : img . caption . rating ( ) )
rating_overall_sum : float = 0.0
ratings_summed : list [ float ] = [ ]
for image in self . prepared_train_data :
rating_overall_sum + = image . caption . rating ( )
ratings_summed . append ( rating_overall_sum )
return rating_overall_sum , ratings_summed
2022-12-17 20:32:48 -07:00
@staticmethod
2023-01-07 09:29:09 -07:00
def __read_caption_from_file ( file_path , fallback_caption : ImageCaption ) - > ImageCaption :
2022-12-17 20:32:48 -07:00
try :
with open ( file_path , encoding = ' utf-8 ' , mode = ' r ' ) as caption_file :
2023-01-07 09:29:09 -07:00
caption_text = caption_file . read ( )
caption = DataLoaderMultiAspect . __split_caption_into_tags ( caption_text )
2022-12-17 20:32:48 -07:00
except :
2022-12-27 12:25:32 -07:00
logging . error ( f " *** Error reading { file_path } to get caption, falling back to filename " )
2022-12-17 20:32:48 -07:00
caption = fallback_caption
pass
return caption
2023-01-07 11:57:23 -07:00
@staticmethod
def __read_caption_from_yaml ( file_path : str , fallback_caption : ImageCaption ) - > ImageCaption :
with open ( file_path , " r " ) as stream :
try :
file_content = yaml . safe_load ( stream )
main_prompt = file_content . get ( " main_prompt " , " " )
2023-01-14 06:00:30 -07:00
rating = file_content . get ( " rating " , 1.0 )
2023-01-07 11:57:23 -07:00
unparsed_tags = file_content . get ( " tags " , [ ] )
2023-01-07 14:59:51 -07:00
max_caption_length = file_content . get ( " max_caption_length " , DEFAULT_MAX_CAPTION_LENGTH )
2023-01-07 11:57:23 -07:00
tags = [ ]
tag_weights = [ ]
2023-01-07 14:59:51 -07:00
last_weight = None
weights_differ = False
2023-01-07 11:57:23 -07:00
for unparsed_tag in unparsed_tags :
tag = unparsed_tag . get ( " tag " , " " ) . strip ( )
if len ( tag ) == 0 :
continue
tags . append ( tag )
2023-01-07 14:59:51 -07:00
tag_weight = unparsed_tag . get ( " weight " , 1.0 )
tag_weights . append ( tag_weight )
if last_weight is not None and weights_differ is False :
weights_differ = last_weight != tag_weight
last_weight = tag_weight
2023-01-07 11:57:23 -07:00
2023-01-14 06:00:30 -07:00
return ImageCaption ( main_prompt , rating , tags , tag_weights , max_caption_length , weights_differ )
2023-01-07 11:57:23 -07:00
except :
logging . error ( f " *** Error reading { file_path } to get caption, falling back to filename " )
return fallback_caption
2023-01-07 09:29:09 -07:00
@staticmethod
def __split_caption_into_tags ( caption_string : str ) - > ImageCaption :
"""
Splits a string by " , " into the main prompt and additional tags with equal weights
"""
split_caption = caption_string . split ( " , " )
main_prompt = split_caption . pop ( 0 ) . strip ( )
tags = [ ]
for tag in split_caption :
tags . append ( tag . strip ( ) )
2023-01-14 06:00:30 -07:00
return ImageCaption ( main_prompt , 1.0 , tags , [ 1.0 ] * len ( tags ) , DEFAULT_MAX_CAPTION_LENGTH , False )
2023-01-07 09:29:09 -07:00
2023-01-14 06:00:30 -07:00
def __prescan_images ( self , image_paths : list , flip_p = 0.0 ) - > list [ ImageTrainItem ] :
2022-12-17 20:32:48 -07:00
"""
Create ImageTrainItem objects with metadata for hydration later
"""
decorated_image_train_items = [ ]
2023-01-20 07:42:24 -07:00
if not self . has_scanned :
undersized_images = [ ]
2023-01-21 23:15:50 -07:00
multipliers = { }
skip_folders = [ ]
2023-01-01 08:45:18 -07:00
for pathname in tqdm . tqdm ( image_paths ) :
2022-12-17 20:32:48 -07:00
caption_from_filename = os . path . splitext ( os . path . basename ( pathname ) ) [ 0 ] . split ( " _ " ) [ 0 ]
2023-01-07 09:29:09 -07:00
caption = DataLoaderMultiAspect . __split_caption_into_tags ( caption_from_filename )
2022-12-17 20:32:48 -07:00
2023-01-07 11:57:23 -07:00
file_path_without_ext = os . path . splitext ( pathname ) [ 0 ]
yaml_file_path = file_path_without_ext + " .yaml "
txt_file_path = file_path_without_ext + " .txt "
caption_file_path = file_path_without_ext + " .caption "
2022-12-17 20:32:48 -07:00
2023-01-21 23:15:50 -07:00
current_dir = os . path . dirname ( pathname )
try :
if current_dir not in multipliers :
multiply_txt_path = os . path . join ( current_dir , " multiply.txt " )
#print(current_dir, multiply_txt_path)
if os . path . exists ( multiply_txt_path ) :
with open ( multiply_txt_path , ' r ' ) as f :
val = float ( f . read ( ) . strip ( ) )
multipliers [ current_dir ] = val
logging . info ( f " * DLMA multiply.txt in { current_dir } set to { val } " )
else :
skip_folders . append ( current_dir )
multipliers [ current_dir ] = 1.0
except Exception as e :
logging . warning ( f " * { Fore . LIGHTYELLOW_EX } Error trying to read multiply.txt for { current_dir } : { Style . RESET_ALL } { e } " )
skip_folders . append ( current_dir )
multipliers [ current_dir ] = 1.0
2023-01-07 11:57:23 -07:00
if os . path . exists ( yaml_file_path ) :
caption = self . __read_caption_from_yaml ( yaml_file_path , caption )
elif os . path . exists ( txt_file_path ) :
2023-01-07 09:29:09 -07:00
caption = self . __read_caption_from_file ( txt_file_path , caption )
2022-12-17 20:32:48 -07:00
elif os . path . exists ( caption_file_path ) :
2023-01-07 09:29:09 -07:00
caption = self . __read_caption_from_file ( caption_file_path , caption )
2022-12-17 20:32:48 -07:00
2022-12-27 12:25:32 -07:00
try :
image = Image . open ( pathname )
width , height = image . size
image_aspect = width / height
2022-12-17 20:32:48 -07:00
2022-12-27 12:25:32 -07:00
target_wh = min ( self . aspects , key = lambda aspects : abs ( aspects [ 0 ] / aspects [ 1 ] - image_aspect ) )
2023-01-20 07:42:24 -07:00
if not self . has_scanned :
if width * height < target_wh [ 0 ] * target_wh [ 1 ] :
2023-01-20 14:23:56 -07:00
undersized_images . append ( f " { pathname } , size: { width } , { height } , target size: { target_wh } " )
2022-12-17 20:32:48 -07:00
2023-01-21 23:15:50 -07:00
image_train_item = ImageTrainItem ( image = None , # image loaded at runtime to apply jitter
caption = caption ,
target_wh = target_wh ,
pathname = pathname ,
flip_p = flip_p ,
multiplier = multipliers [ current_dir ] ,
)
2022-12-17 20:32:48 -07:00
2022-12-27 12:25:32 -07:00
decorated_image_train_items . append ( image_train_item )
2023-01-20 07:42:24 -07:00
2022-12-27 12:25:32 -07:00
except Exception as e :
logging . error ( f " { Fore . LIGHTRED_EX } *** Error opening { Fore . LIGHTYELLOW_EX } { pathname } { Fore . LIGHTRED_EX } to get metadata. File may be corrupt and will be skipped. { Style . RESET_ALL } " )
logging . error ( f " *** exception: { e } " )
pass
2022-12-17 20:32:48 -07:00
2023-01-20 07:42:24 -07:00
if not self . has_scanned :
self . has_scanned = True
if len ( undersized_images ) > 0 :
underized_log_path = os . path . join ( self . log_folder , " undersized_images.txt " )
logging . warning ( f " { Fore . LIGHTRED_EX } ** Some images are smaller than the target size, consider using larger images { Style . RESET_ALL } " )
logging . warning ( f " { Fore . LIGHTRED_EX } ** Check { underized_log_path } for more information. { Style . RESET_ALL } " )
with open ( underized_log_path , " w " ) as undersized_images_file :
undersized_images_file . write ( f " The following images are smaller than the target size, consider removing or sourcing a larger copy: " )
for undersized_image in undersized_images :
2023-01-20 14:23:56 -07:00
undersized_images_file . write ( f " { undersized_image } \n " )
2023-01-20 07:42:24 -07:00
2022-12-17 20:32:48 -07:00
return decorated_image_train_items
2023-01-14 06:00:30 -07:00
def __pick_random_subset ( self , dropout_fraction : float , picker : random . Random ) - > list [ ImageTrainItem ] :
2022-12-17 20:32:48 -07:00
"""
2023-01-14 06:00:30 -07:00
Picks a random subset of all images
- The size of the subset is limited by dropout_faction
- The chance of an image to be picked is influenced by its rating . Double that rating - > double the chance
: param dropout_fraction : must be between 0.0 and 1.0
: param picker : seeded random picker
: return : list of picked ImageTrainItem
2022-12-17 20:32:48 -07:00
"""
2023-01-14 06:00:30 -07:00
prepared_train_data = self . prepared_train_data . copy ( )
ratings_summed = self . ratings_summed . copy ( )
rating_overall_sum = self . rating_overall_sum
2022-12-17 20:32:48 -07:00
2023-01-14 06:00:30 -07:00
num_images = len ( prepared_train_data )
num_images_to_pick = math . ceil ( num_images * dropout_fraction )
num_images_to_pick = max ( min ( num_images_to_pick , num_images ) , 0 )
2022-12-17 20:32:48 -07:00
2023-01-14 06:00:30 -07:00
# logging.info(f"Picking {num_images_to_pick} images out of the {num_images} in the dataset for drop_fraction {dropout_fraction}")
2023-01-01 08:45:18 -07:00
2023-01-14 06:00:30 -07:00
picked_images : list [ ImageTrainItem ] = [ ]
while num_images_to_pick > len ( picked_images ) :
# find random sample in dataset
point = picker . uniform ( 0.0 , rating_overall_sum )
pos = min ( bisect . bisect_left ( ratings_summed , point ) , len ( prepared_train_data ) - 1 )
2023-01-01 08:45:18 -07:00
2023-01-14 06:00:30 -07:00
# pick random sample
picked_image = prepared_train_data [ pos ]
picked_images . append ( picked_image )
2022-12-17 20:32:48 -07:00
2023-01-14 06:00:30 -07:00
# kick picked item out of data set to not pick it again
rating_overall_sum = max ( rating_overall_sum - picked_image . caption . rating ( ) , 0.0 )
ratings_summed . pop ( pos )
prepared_train_data . pop ( pos )
2022-12-17 20:32:48 -07:00
2023-01-14 06:00:30 -07:00
return picked_images
2022-12-17 20:32:48 -07:00
@staticmethod
def __recurse_data_root ( self , recurse_root ) :
for f in os . listdir ( recurse_root ) :
current = os . path . join ( recurse_root , f )
if os . path . isfile ( current ) :
2023-01-10 03:54:26 -07:00
ext = os . path . splitext ( f ) [ 1 ] . lower ( )
2022-12-18 11:03:44 -07:00
if ext in [ ' .jpg ' , ' .jpeg ' , ' .png ' , ' .bmp ' , ' .webp ' , ' .jfif ' ] :
2023-01-21 23:15:50 -07:00
self . image_paths . append ( current )
2022-12-17 20:32:48 -07:00
sub_dirs = [ ]
for d in os . listdir ( recurse_root ) :
current = os . path . join ( recurse_root , d )
if os . path . isdir ( current ) :
sub_dirs . append ( current )
for dir in sub_dirs :
self . __recurse_data_root ( self = self , recurse_root = dir )