make fractional multiplier logic apply per-directory
This commit is contained in:
parent
a7b00e9ef3
commit
19347bcaa8
|
@ -15,13 +15,13 @@ limitations under the License.
|
|||
"""
|
||||
import bisect
|
||||
import logging
|
||||
from functools import reduce
|
||||
import os.path
|
||||
from collections import defaultdict
|
||||
import math
|
||||
import copy
|
||||
|
||||
import random
|
||||
from data.image_train_item import ImageTrainItem, ImageCaption
|
||||
import PIL
|
||||
from data.image_train_item import ImageTrainItem
|
||||
import PIL.Image
|
||||
|
||||
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # raise Pillow's decompression bomb limit to 32x the default MAX_IMAGE_PIXELS (89478485)
|
||||
|
||||
|
@ -39,9 +39,9 @@ class DataLoaderMultiAspect():
|
|||
self.prepared_train_data = image_train_items
|
||||
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
||||
self.epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
|
||||
if self.epoch_size != len(self.prepared_train_data):
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.epoch_size} images.")
|
||||
expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
|
||||
if expected_epoch_size != len(self.prepared_train_data):
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {expected_epoch_size} images.")
|
||||
else:
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} images.")
|
||||
|
||||
|
@ -50,12 +50,12 @@ class DataLoaderMultiAspect():
|
|||
self.__update_rating_sums()
|
||||
|
||||
|
||||
def __pick_multiplied_set(self, randomizer):
|
||||
def __pick_multiplied_set(self, randomizer: random.Random):
|
||||
"""
|
||||
Applies the whole-number and fractional parts of each item's multiply.txt multiplier
|
||||
"""
|
||||
picked_images = []
|
||||
fractional_images = []
|
||||
fractional_images_per_directory = defaultdict(list[ImageTrainItem])
|
||||
for iti in self.prepared_train_data:
|
||||
multiplier = iti.multiplier
|
||||
while multiplier >= 1:
|
||||
|
@ -63,20 +63,15 @@ class DataLoaderMultiAspect():
|
|||
multiplier -= 1
|
||||
# fractional remainders must be dealt with separately
|
||||
if multiplier > 0:
|
||||
fractional_images.append((iti, multiplier))
|
||||
directory = os.path.dirname(iti.pathname)
|
||||
fractional_images_per_directory[directory].append(iti)
|
||||
|
||||
target_epoch_size = self.epoch_size
|
||||
while len(picked_images) < target_epoch_size and len(fractional_images) > 0:
|
||||
# cycle through fractional_images, randomly shifting each over to picked_images based on its multiplier
|
||||
iti, multiplier = fractional_images.pop(0)
|
||||
if randomizer.uniform(0, 1) < multiplier:
|
||||
# shift it over to picked_images
|
||||
picked_images.append(iti)
|
||||
else:
|
||||
# put it back and move on to the next
|
||||
fractional_images.append((iti, multiplier))
|
||||
|
||||
assert len(picked_images) == target_epoch_size, "Something went wrong while attempting to apply multipliers"
|
||||
# resolve fractional parts per-directory
|
||||
for _, fractional_items in fractional_images_per_directory.items():
|
||||
randomizer.shuffle(fractional_items)
|
||||
multiplier = fractional_items[0].multiplier % 1.0
|
||||
count_to_take = math.ceil(multiplier * len(fractional_items))
|
||||
picked_images.extend(fractional_items[:count_to_take])
|
||||
|
||||
return picked_images
|
||||
|
||||
|
|
Loading…
Reference in New Issue