fix multiplier logic
This commit is contained in:
parent
4e37200dda
commit
a7b00e9ef3
|
@ -39,14 +39,15 @@ class DataLoaderMultiAspect():
|
|||
self.prepared_train_data = image_train_items
|
||||
random.Random(self.seed).shuffle(self.prepared_train_data)
|
||||
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
|
||||
self.epoch_size = math.floor(sum([i.multiplier for i in self.prepared_train_data]))
|
||||
if self.epoch_size != len(self.prepared_train_data):
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} source images. After applying multipliers, each epoch will train on at least {self.epoch_size} images.")
|
||||
else:
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} images.")
|
||||
|
||||
self.rating_overall_sum: float = 0.0
|
||||
self.ratings_summed: list[float] = []
|
||||
self.__update_rating_sums()
|
||||
count_including_multipliers = sum([math.floor(max(i.multiplier, 1)) for i in self.prepared_train_data])
|
||||
if count_including_multipliers > len(self.prepared_train_data):
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} items ({count_including_multipliers} items total after applying multipliers)")
|
||||
else:
|
||||
logging.info(f" * DLMA initialized with {len(image_train_items)} items")
|
||||
|
||||
|
||||
def __pick_multiplied_set(self, randomizer):
|
||||
|
@ -54,14 +55,28 @@ class DataLoaderMultiAspect():
|
|||
Deals with multiply.txt whole and fractional numbers
|
||||
"""
|
||||
picked_images = []
|
||||
fractional_images = []
|
||||
for iti in self.prepared_train_data:
|
||||
multiplier = iti.multiplier
|
||||
while multiplier >= 1:
|
||||
picked_images.append(iti)
|
||||
multiplier -= 1
|
||||
# deal with fractional remainder
|
||||
if multiplier > randomizer.uniform(0, 1):
|
||||
# fractional remainders must be dealt with separately
|
||||
if multiplier > 0:
|
||||
fractional_images.append((iti, multiplier))
|
||||
|
||||
target_epoch_size = self.epoch_size
|
||||
while len(picked_images) < target_epoch_size and len(fractional_images) > 0:
|
||||
# cycle through fractional_images, randomly shifting each over to picked_images based on its multiplier
|
||||
iti, multiplier = fractional_images.pop(0)
|
||||
if randomizer.uniform(0, 1) < multiplier:
|
||||
# shift it over to picked_images
|
||||
picked_images.append(iti)
|
||||
else:
|
||||
# put it back and move on to the next
|
||||
fractional_images.append((iti, multiplier))
|
||||
|
||||
assert len(picked_images) == target_epoch_size, "Something went wrong while attempting to apply multipliers"
|
||||
|
||||
return picked_images
|
||||
|
||||
|
@ -98,20 +113,19 @@ class DataLoaderMultiAspect():
|
|||
buckets[(target_wh[0],target_wh[1])] = []
|
||||
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
|
||||
|
||||
if len(buckets) > 1:
|
||||
for bucket in buckets:
|
||||
truncate_count = len(buckets[bucket]) % batch_size
|
||||
if truncate_count > 0:
|
||||
runt_bucket = buckets[bucket][-truncate_count:]
|
||||
for item in runt_bucket:
|
||||
item.runt_size = truncate_count
|
||||
while len(runt_bucket) < batch_size:
|
||||
runt_bucket.append(random.choice(runt_bucket))
|
||||
for bucket in buckets:
|
||||
truncate_count = len(buckets[bucket]) % batch_size
|
||||
if truncate_count > 0:
|
||||
runt_bucket = buckets[bucket][-truncate_count:]
|
||||
for item in runt_bucket:
|
||||
item.runt_size = truncate_count
|
||||
while len(runt_bucket) < batch_size:
|
||||
runt_bucket.append(random.choice(runt_bucket))
|
||||
|
||||
current_bucket_size = len(buckets[bucket])
|
||||
current_bucket_size = len(buckets[bucket])
|
||||
|
||||
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
||||
buckets[bucket].extend(runt_bucket)
|
||||
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
|
||||
buckets[bucket].extend(runt_bucket)
|
||||
|
||||
# flatten the buckets
|
||||
items: list[ImageTrainItem] = []
|
||||
|
|
33
train.py
33
train.py
|
@ -160,20 +160,21 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
|
|||
"""
|
||||
updates the vram usage for the epoch
|
||||
"""
|
||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||
log_writer.add_scalar("performance/vram", gpu_used_mem, global_step)
|
||||
epoch_mem_color = Style.RESET_ALL
|
||||
if gpu_used_mem > 0.93 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTRED_EX
|
||||
elif gpu_used_mem > 0.85 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTYELLOW_EX
|
||||
elif gpu_used_mem > 0.7 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTGREEN_EX
|
||||
elif gpu_used_mem < 0.5 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTBLUE_EX
|
||||
if gpu is not None:
|
||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||
log_writer.add_scalar("performance/vram", gpu_used_mem, global_step)
|
||||
epoch_mem_color = Style.RESET_ALL
|
||||
if gpu_used_mem > 0.93 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTRED_EX
|
||||
elif gpu_used_mem > 0.85 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTYELLOW_EX
|
||||
elif gpu_used_mem > 0.7 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTGREEN_EX
|
||||
elif gpu_used_mem < 0.5 * gpu_total_mem:
|
||||
epoch_mem_color = Fore.LIGHTBLUE_EX
|
||||
|
||||
if logs is not None:
|
||||
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
|
||||
if logs is not None:
|
||||
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
|
||||
|
||||
|
||||
def set_args_12gb(args):
|
||||
|
@ -372,6 +373,7 @@ def main(args):
|
|||
else:
|
||||
logging.warning("*** Running on CPU. This is for testing loading/config parsing code only.")
|
||||
device = 'cpu'
|
||||
gpu = None
|
||||
|
||||
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
||||
|
||||
|
@ -714,8 +716,9 @@ def main(args):
|
|||
if not os.path.exists(f"{log_folder}/samples/"):
|
||||
os.makedirs(f"{log_folder}/samples/")
|
||||
|
||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
|
||||
if gpu is not None:
|
||||
gpu_used_mem, gpu_total_mem = gpu.get_gpu_memory()
|
||||
logging.info(f" Pretraining GPU Memory: {gpu_used_mem} / {gpu_total_mem} MB")
|
||||
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
|
||||
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
|
||||
|
||||
|
|
Loading…
Reference in New Issue