fix multiplier logic

This commit is contained in:
Damian Stewart 2023-02-08 13:46:58 +01:00
parent 4e37200dda
commit a7b00e9ef3
2 changed files with 51 additions and 34 deletions

View File

@ -39,14 +39,15 @@ class DataLoaderMultiAspect():
self.prepared_train_data = image_train_items
random.Random(self.seed).shuffle(self.prepared_train_data)
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
self.epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
if self.epoch_size != len(self.prepared_train_data):
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.epoch_size} images.")
else:
logging.info(f" * DLMA initialized with {len(image_train_items)} images.")
self.rating_overall_sum: float = 0.0
self.ratings_summed: list[float] = []
self.__update_rating_sums()
count_including_multipliers = sum([math.floor(max(i.multiplier, 1)) for i in self.prepared_train_data])
if count_including_multipliers > len(self.prepared_train_data):
logging.info(f" * DLMA initialized with {len(image_train_items)} items ({count_including_multipliers} items total after applying multipliers)")
else:
logging.info(f" * DLMA initialized with {len(image_train_items)} items")
def __pick_multiplied_set(self, randomizer):
@ -54,14 +55,28 @@ class DataLoaderMultiAspect():
Deals with multiply.txt whole and fractional numbers
"""
picked_images = []
fractional_images = []
for iti in self.prepared_train_data:
multiplier = iti.multiplier
while multiplier >= 1:
picked_images.append(iti)
multiplier -= 1
# deal with fractional remainder
if multiplier > randomizer.uniform(0, 1):
# fractional remainders must be dealt with separately
if multiplier > 0:
fractional_images.append((iti, multiplier))
target_epoch_size = self.epoch_size
while len(picked_images) < target_epoch_size and len(fractional_images) > 0:
# cycle through fractional_images, randomly shifting each over to picked_images based on its multiplier
iti, multiplier = fractional_images.pop(0)
if randomizer.uniform(0, 1) < multiplier:
# shift it over to picked_images
picked_images.append(iti)
else:
# put it back and move on to the next
fractional_images.append((iti, multiplier))
assert len(picked_images) == target_epoch_size, "Something went wrong while attempting to apply multipliers"
return picked_images
@ -98,20 +113,19 @@ class DataLoaderMultiAspect():
buckets[(target_wh[0],target_wh[1])] = []
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
if len(buckets) > 1:
for bucket in buckets:
truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0:
runt_bucket = buckets[bucket][-truncate_count:]
for item in runt_bucket:
item.runt_size = truncate_count
while len(runt_bucket) < batch_size:
runt_bucket.append(random.choice(runt_bucket))
for bucket in buckets:
truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0:
runt_bucket = buckets[bucket][-truncate_count:]
for item in runt_bucket:
item.runt_size = truncate_count
while len(runt_bucket) < batch_size:
runt_bucket.append(random.choice(runt_bucket))
current_bucket_size = len(buckets[bucket])
current_bucket_size = len(buckets[bucket])
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
buckets[bucket].extend(runt_bucket)
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
buckets[bucket].extend(runt_bucket)
# flatten the buckets
items: list[ImageTrainItem] = []

View File

@ -160,20 +160,21 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
"""
updates the vram usage for the epoch
"""
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
log_writer.add_scalar("performance/vram", gpu_used_mem, global_step)
epoch_mem_color = Style.RESET_ALL
if gpu_used_mem > 0.93 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTRED_EX
elif gpu_used_mem > 0.85 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTYELLOW_EX
elif gpu_used_mem > 0.7 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTGREEN_EX
elif gpu_used_mem < 0.5 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTBLUE_EX
if gpu is not None:
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
log_writer.add_scalar("performance/vram", gpu_used_mem, global_step)
epoch_mem_color = Style.RESET_ALL
if gpu_used_mem > 0.93 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTRED_EX
elif gpu_used_mem > 0.85 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTYELLOW_EX
elif gpu_used_mem > 0.7 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTGREEN_EX
elif gpu_used_mem < 0.5 * gpu_total_mem:
epoch_mem_color = Fore.LIGHTBLUE_EX
if logs is not None:
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
if logs is not None:
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
def set_args_12gb(args):
@ -372,6 +373,7 @@ def main(args):
else:
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
device = 'cpu'
gpu = None
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
@ -714,8 +716,9 @@ def main(args):
if not os.path.exists(f"{log_folder}/samples/"):
os.makedirs(f"{log_folder}/samples/")
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
if gpu is not None:
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")