make fractional multiplier logic apply per-directory

This commit is contained in:
Damian Stewart 2023-02-08 14:15:54 +01:00
parent a7b00e9ef3
commit 19347bcaa8
2 changed files with 18 additions and 22 deletions

View File

@ -15,13 +15,13 @@ limitations under the License.
"""
import bisect
import logging
from functools import reduce
import os.path
from collections import defaultdict
import math
import copy
import random
from data.image_train_item import ImageTrainItem, ImageCaption
import PIL
from data.image_train_item import ImageTrainItem
import PIL.Image
# Override Pillow's module-level pixel-count threshold used for decompression-bomb checks.
# NOTE(review): Pillow's stock MAX_IMAGE_PIXELS is 89478485 (1024*1024*1024 // 4 // 3), so
# 715827880*4 == 2863311520 is 32x the stock value, not 4x -- confirm which baseline the
# "4x default" wording refers to (possibly a previous project setting of 715827880).
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@ -39,9 +39,9 @@ class DataLoaderMultiAspect():
self.prepared_train_data = image_train_items
random.Random(self.seed).shuffle(self.prepared_train_data)
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
self.epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
if self.epoch_size != len(self.prepared_train_data):
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.epoch_size} images.")
expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
if expected_epoch_size != len(self.prepared_train_data):
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {expected_epoch_size} images.")
else:
logging.info(f" * DLMA initialized with {len(image_train_items)} images.")
@ -50,12 +50,12 @@ class DataLoaderMultiAspect():
self.__update_rating_sums()
def __pick_multiplied_set(self, randomizer):
def __pick_multiplied_set(self, randomizer: random.Random):
"""
Deals with multiply.txt whole and fractional numbers
"""
picked_images = []
fractional_images = []
fractional_images_per_directory = defaultdict(list[ImageTrainItem])
for iti in self.prepared_train_data:
multiplier = iti.multiplier
while multiplier >= 1:
@ -63,20 +63,15 @@ class DataLoaderMultiAspect():
multiplier -= 1
# fractional remainders must be dealt with separately
if multiplier > 0:
fractional_images.append((iti, multiplier))
directory = os.path.dirname(iti.pathname)
fractional_images_per_directory[directory].append(iti)
target_epoch_size = self.epoch_size
while len(picked_images) < target_epoch_size and len(fractional_images) > 0:
# cycle through fractional_images, randomly shifting each over to picked_images based on its multiplier
iti, multiplier = fractional_images.pop(0)
if randomizer.uniform(0, 1) < multiplier:
# shift it over to picked_images
picked_images.append(iti)
else:
# put it back and move on to the next
fractional_images.append((iti, multiplier))
assert len(picked_images) == target_epoch_size, "Something went wrong while attempting to apply multipliers"
# resolve fractional parts per-directory
for _, fractional_items in fractional_images_per_directory.items():
randomizer.shuffle(fractional_items)
multiplier = fractional_items[0].multiplier % 1.0
count_to_take = math.ceil(multiplier * len(fractional_items))
picked_images.extend(fractional_items[:count_to_take])
return picked_images

View File

@ -550,6 +550,7 @@ def main(args):
except Exception as e:
traceback.print_exc()
logging.error(" * Failed to load checkpoint *")
raise
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()