make fractional multiplier logic apply per-directory

Damian Stewart 2023-02-08 14:15:54 +01:00
parent a7b00e9ef3
commit 19347bcaa8
2 changed files with 18 additions and 22 deletions
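
What changes, in brief: the old __pick_multiplied_set resolved every fractional multiplier independently, cycling the pooled fractional images and keeping each one with probability equal to its remainder until the epoch size was reached. The new version groups fractional images by parent directory, shuffles each group, and keeps a proportional slice of it, ceil(remainder * group_size), so a directory's multiply.txt fraction maps to a predictable image count per epoch.

A minimal standalone sketch of the new selection rule (toy data; ImageTrainItem is reduced to a (pathname, multiplier) tuple, and all names here are hypothetical, not from the repo):

import math
import os.path
import random
from collections import defaultdict

# toy stand-in for ImageTrainItem: (pathname, multiplier) tuples
items = [(f"dogs/{i}.jpg", 1.5) for i in range(4)] + \
        [(f"cats/{i}.jpg", 0.5) for i in range(4)]

randomizer = random.Random(42)
picked = []
fractional_per_directory = defaultdict(list)

for pathname, multiplier in items:
    # whole part: repeat the image floor(multiplier) times
    while multiplier >= 1:
        picked.append(pathname)
        multiplier -= 1
    # fractional remainder: defer, grouped by parent directory
    if multiplier > 0:
        fractional_per_directory[os.path.dirname(pathname)].append((pathname, multiplier))

# per directory, keep a proportional slice of the shuffled remainder
# (every item in a directory shares the same multiply.txt value)
for directory, frac in fractional_per_directory.items():
    randomizer.shuffle(frac)
    count_to_take = math.ceil((frac[0][1] % 1.0) * len(frac))
    picked.extend(p for p, _ in frac[:count_to_take])

print(len(picked))  # 4 whole "dogs" copies + ceil(0.5*4) from each directory = 8

Re-running this with a differently seeded randomizer changes which images fill the fractional slots, but not how many are taken per directory.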

@@ -15,13 +15,13 @@ limitations under the License.
 """
 import bisect
 import logging
-from functools import reduce
+import os.path
+from collections import defaultdict
 import math
-import copy
 import random
-from data.image_train_item import ImageTrainItem, ImageCaption
-import PIL
+from data.image_train_item import ImageTrainItem
+import PIL.Image
 PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
@@ -39,9 +39,9 @@ class DataLoaderMultiAspect():
         self.prepared_train_data = image_train_items
         random.Random(self.seed).shuffle(self.prepared_train_data)
         self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
-        self.epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
-        if self.epoch_size != len(self.prepared_train_data):
-            logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.epoch_size} images.")
+        expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
+        if expected_epoch_size != len(self.prepared_train_data):
+            logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {expected_epoch_size} images.")
         else:
             logging.info(f" * DLMA initialized with {len(image_train_items)} images.")
@@ -50,12 +50,12 @@ class DataLoaderMultiAspect():
         self.__update_rating_sums()

-    def __pick_multiplied_set(self, randomizer):
+    def __pick_multiplied_set(self, randomizer: random.Random):
         """
         Deals with multiply.txt whole and fractional numbers
         """
         picked_images = []
-        fractional_images = []
+        fractional_images_per_directory = defaultdict(list[ImageTrainItem])
         for iti in self.prepared_train_data:
             multiplier = iti.multiplier
             while multiplier >= 1:
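
One detail in this hunk that can look like a typo: defaultdict(list[ImageTrainItem]) uses a subscripted generic alias as the default factory. Calling a generic alias instantiates its origin type (Python 3.9+), so list[ImageTrainItem]() just returns an empty list at runtime; the subscript is documentation only. A tiny demonstration:

from collections import defaultdict

d = defaultdict(list[int])   # same runtime behavior as defaultdict(list)
d["a"].append(1)
print(d)                     # defaultdict(list[int], {'a': [1]})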
@@ -63,20 +63,15 @@ class DataLoaderMultiAspect():
                 multiplier -= 1
             # fractional remainders must be dealt with separately
             if multiplier > 0:
-                fractional_images.append((iti, multiplier))
-
-        target_epoch_size = self.epoch_size
-        while len(picked_images) < target_epoch_size and len(fractional_images) > 0:
-            # cycle through fractional_images, randomly shifting each over to picked_images based on its multiplier
-            iti, multiplier = fractional_images.pop(0)
-            if randomizer.uniform(0, 1) < multiplier:
-                # shift it over to picked_images
-                picked_images.append(iti)
-            else:
-                # put it back and move on to the next
-                fractional_images.append((iti, multiplier))
-        assert len(picked_images) == target_epoch_size, "Something went wrong while attempting to apply multipliers"
+                directory = os.path.dirname(iti.pathname)
+                fractional_images_per_directory[directory].append(iti)
+
+        # resolve fractional parts per-directory
+        for _, fractional_items in fractional_images_per_directory.items():
+            randomizer.shuffle(fractional_items)
+            multiplier = fractional_items[0].multiplier % 1.0
+            count_to_take = math.ceil(multiplier * len(fractional_items))
+            picked_images.extend(fractional_items[:count_to_take])

         return picked_images
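
Net effect of the hunk above: a directory holding 10 images with a multiply.txt of 0.4 now contributes exactly ceil(0.4 * 10) = 4 images to the epoch, with the shuffle deciding which four, instead of each image independently passing a 0.4 coin flip. Two consequences worth noting: the remainder is read from fractional_items[0].multiplier, which is sound as long as every item in a directory shares one multiplier (the case when it comes from a per-directory multiply.txt); and because ceil rounds up per directory, the picked set can slightly exceed the floor-based total, which is why the log message in the earlier hunk says "at least {expected_epoch_size} images" and the old length assert was dropped.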

@@ -550,6 +550,7 @@ def main(args):
     except Exception as e:
         traceback.print_exc()
         logging.error(" * Failed to load checkpoint *")
+        raise

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
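
The bare raise in this second file re-raises the active exception after it has been printed and logged, so a failed checkpoint load now stops the run immediately rather than letting training continue past the error.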