make fractional multiplier logic apply per-directory
This commit is contained in:
parent
a7b00e9ef3
commit
19347bcaa8
|
@ -15,13 +15,13 @@ limitations under the License.
|
||||||
"""
|
"""
|
||||||
import bisect
|
import bisect
|
||||||
import logging
|
import logging
|
||||||
from functools import reduce
|
import os.path
|
||||||
|
from collections import defaultdict
|
||||||
import math
|
import math
|
||||||
import copy
|
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from data.image_train_item import ImageTrainItem, ImageCaption
|
from data.image_train_item import ImageTrainItem
|
||||||
import PIL
|
import PIL.Image
|
||||||
|
|
||||||
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
|
||||||
|
|
||||||
|
@ -39,9 +39,9 @@ class DataLoaderMultiAspect():
|
||||||
self.prepared_train_data = image_train_items
|
self.prepared_train_data = image_train_items
|
||||||
random.Random(self.seed).shuffle(self.prepared_train_data)
|
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||||
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
||||||
self.epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
|
expected_epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
|
||||||
if self.epoch_size != len(self.prepared_train_data):
|
if expected_epoch_size != len(self.prepared_train_data):
|
||||||
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.epoch_size} images.")
|
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {expected_epoch_size} images.")
|
||||||
else:
|
else:
|
||||||
logging.info(f" * DLMA initialized with {len(image_train_items)} images.")
|
logging.info(f" * DLMA initialized with {len(image_train_items)} images.")
|
||||||
|
|
||||||
|
@ -50,12 +50,12 @@ class DataLoaderMultiAspect():
|
||||||
self.__update_rating_sums()
|
self.__update_rating_sums()
|
||||||
|
|
||||||
|
|
||||||
def __pick_multiplied_set(self, randomizer):
|
def __pick_multiplied_set(self, randomizer: random.Random):
|
||||||
"""
|
"""
|
||||||
Deals with multiply.txt whole and fractional numbers
|
Deals with multiply.txt whole and fractional numbers
|
||||||
"""
|
"""
|
||||||
picked_images = []
|
picked_images = []
|
||||||
fractional_images = []
|
fractional_images_per_directory = defaultdict(list[ImageTrainItem])
|
||||||
for iti in self.prepared_train_data:
|
for iti in self.prepared_train_data:
|
||||||
multiplier = iti.multiplier
|
multiplier = iti.multiplier
|
||||||
while multiplier >= 1:
|
while multiplier >= 1:
|
||||||
|
@ -63,20 +63,15 @@ class DataLoaderMultiAspect():
|
||||||
multiplier -= 1
|
multiplier -= 1
|
||||||
# fractional remainders must be dealt with separately
|
# fractional remainders must be dealt with separately
|
||||||
if multiplier > 0:
|
if multiplier > 0:
|
||||||
fractional_images.append((iti, multiplier))
|
directory = os.path.dirname(iti.pathname)
|
||||||
|
fractional_images_per_directory[directory].append(iti)
|
||||||
|
|
||||||
target_epoch_size = self.epoch_size
|
# resolve fractional parts per-directory
|
||||||
while len(picked_images) < target_epoch_size and len(fractional_images) > 0:
|
for _, fractional_items in fractional_images_per_directory.items():
|
||||||
# cycle through fractional_images, randomly shifting each over to picked_images based on its multiplier
|
randomizer.shuffle(fractional_items)
|
||||||
iti, multiplier = fractional_images.pop(0)
|
multiplier = fractional_items[0].multiplier % 1.0
|
||||||
if randomizer.uniform(0, 1) < multiplier:
|
count_to_take = math.ceil(multiplier * len(fractional_items))
|
||||||
# shift it over to picked_images
|
picked_images.extend(fractional_items[:count_to_take])
|
||||||
picked_images.append(iti)
|
|
||||||
else:
|
|
||||||
# put it back and move on to the next
|
|
||||||
fractional_images.append((iti, multiplier))
|
|
||||||
|
|
||||||
assert len(picked_images) == target_epoch_size, "Something went wrong while attempting to apply multipliers"
|
|
||||||
|
|
||||||
return picked_images
|
return picked_images
|
||||||
|
|
||||||
|
|
1
train.py
1
train.py
|
@ -550,6 +550,7 @@ def main(args):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
logging.error(" * Failed to load checkpoint *")
|
logging.error(" * Failed to load checkpoint *")
|
||||||
|
raise
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
|
|
Loading…
Reference in New Issue