Merge pull request #81 from damian0815/fix-validation-div-by-zero

Fix validation division by zero
This commit is contained in:
Victor Hall 2023-02-19 05:46:43 -08:00 committed by GitHub
commit eec363899e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 13 deletions

View File

@ -23,7 +23,7 @@ from utils.isolate_rng import isolate_rng
def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
-> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
split_item_count = math.ceil(split_proportion * len(items) // batch_size) * batch_size
split_item_count = math.ceil(split_proportion * len(items) / batch_size) * batch_size
# sort first, then shuffle, to ensure determinate outcome for the current random state
items_copy = list(sorted(items, key=lambda i: i.pathname))
random.shuffle(items_copy)

View File

@ -711,6 +711,10 @@ def main(args):
return model_pred, target
# Pre-train validation to establish a starting point on the loss graph
if validator:
validator.do_validation_if_appropriate(epoch=0, global_step=0,
get_model_prediction_and_target_callable=get_model_prediction_and_target)
try:
# # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
@ -849,7 +853,7 @@ def main(args):
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
if validator:
validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target)
gc.collect()
# end of epoch

View File

@ -8,16 +8,21 @@ from typing import Optional
from tqdm.auto import tqdm
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']
CAPTION_EXTENSIONS = ['.txt', '.caption', '.yaml', '.yml']
def gather_captioned_images(root_dir: str) -> list[tuple[str,Optional[str]]]:
for directory, _, filenames in os.walk(root_dir):
image_filenames = [f for f in filenames if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS]
for image_filename in image_filenames:
caption_filename = os.path.splitext(image_filename)[0] + '.txt'
image_path = os.path.join(directory+image_filename)
caption_path = os.path.join(directory+caption_filename)
yield (image_path, caption_path if os.path.exists(caption_path) else None)
image_path = os.path.join(directory, image_filename)
image_path_without_extension = os.path.splitext(image_path)[0]
caption_path = None
for caption_extension in CAPTION_EXTENSIONS:
possible_caption_path = image_path_without_extension + caption_extension
if os.path.exists(possible_caption_path):
caption_path = possible_caption_path
break
yield image_path, caption_path
def copy_captioned_image(image_caption_pair: tuple[str, Optional[str]], source_root: str, target_root: str):
@ -39,13 +44,13 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('source_root', type=str, help='Source root folder containing images')
parser.add_argument('--train_output_folder', type=str, required=True, help="Output folder for the 'train' dataset")
parser.add_argument('--val_output_folder', type=str, required=True, help="Output folder for the 'val' dataset")
parser.add_argument('--train_output_folder', type=str, required=False, help="Output folder for the 'train' dataset. If omitted, do not save the train split.")
parser.add_argument('--val_output_folder', type=str, required=True, help="Output folder for the 'val' dataset.")
parser.add_argument('--split_proportion', type=float, required=True, help="Proportion of images to use for 'val' (a number between 0 and 1)")
parser.add_argument('--seed', type=int, required=False, default=555, help='Random seed for shuffling')
args = parser.parse_args()
images = gather_captioned_images(args.source_root)
images = list(gather_captioned_images(args.source_root))
print(f"Found {len(images)} captioned images in {args.source_root}")
val_split_count = math.ceil(len(images) * args.split_proportion)
if val_split_count == 0:
@ -59,7 +64,9 @@ if __name__ == '__main__':
print(f"Copying 'val' set to {args.val_output_folder}...")
for v in tqdm(val_split):
copy_captioned_image(v, args.source_root, args.val_output_folder)
print(f"Copying 'train' set to {args.train_output_folder}...")
for v in tqdm(train_split):
copy_captioned_image(v, args.source_root, args.train_output_folder)
if args.train_output_folder is not None:
print(f"Copying 'train' set to {args.train_output_folder}...")
for v in tqdm(train_split):
copy_captioned_image(v, args.source_root, args.train_output_folder)
print("Done.")