2023-02-06 23:10:34 -07:00
import argparse
import math
import os.path
import random
import shutil
from collections.abc import Iterable, Iterator
from typing import Optional

from tqdm.auto import tqdm
# File suffixes that mark a file as an image. Entries are lower-case and
# include the leading dot so they compare directly against
# ``os.path.splitext(name)[1].lower()``.
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']

# Suffixes probed, in this order, when looking for the caption file that
# sits next to an image; the first one that exists is used.
CAPTION_EXTENSIONS = ['.txt', '.caption', '.yaml', '.yml']
2023-02-06 23:10:34 -07:00
def gather_captioned_images(
    root_dir: str,
    image_extensions: Optional[Iterable[str]] = None,
    caption_extensions: Optional[Iterable[str]] = None,
) -> Iterator[tuple[str, Optional[str]]]:
    """Yield ``(image_path, caption_path)`` for every image under *root_dir*.

    The tree is walked recursively. A file counts as an image when its
    extension, compared case-insensitively, is one of *image_extensions*.
    For each image the caption is the first existing regular file named
    ``<image path without extension> + <caption extension>``, trying
    *caption_extensions* in the order given.

    Directories and files are visited in sorted order. ``os.walk`` otherwise
    reports entries in whatever order the filesystem returns them, which
    would make a later seeded shuffle of the results non-reproducible.

    Args:
        root_dir: Folder to search.
        image_extensions: Suffixes (with leading dot) that identify images.
            Defaults to the module-level ``IMAGE_EXTENSIONS``.
        caption_extensions: Suffixes (with leading dot) to probe for a
            caption, in priority order. Defaults to the module-level
            ``CAPTION_EXTENSIONS``.

    Yields:
        Tuples of the image path and the caption path, or ``None`` in place
        of the caption path when no caption file exists.
    """
    if image_extensions is None:
        image_extensions = IMAGE_EXTENSIONS
    if caption_extensions is None:
        caption_extensions = CAPTION_EXTENSIONS

    # Normalise once, outside the walk: a set gives O(1) membership tests and
    # the tuple lets the caption suffixes be re-iterated for every image.
    image_suffixes = frozenset(ext.lower() for ext in image_extensions)
    caption_suffixes = tuple(caption_extensions)

    for directory, dirnames, filenames in os.walk(root_dir):
        # Sorting in place makes os.walk descend into sub-folders in order.
        dirnames.sort()
        for filename in sorted(filenames):
            if os.path.splitext(filename)[1].lower() not in image_suffixes:
                continue
            image_path = os.path.join(directory, filename)
            image_path_without_extension = os.path.splitext(image_path)[0]
            candidates = (image_path_without_extension + ext for ext in caption_suffixes)
            # isfile (not exists): a directory that happens to carry a caption
            # suffix is not a caption and could not be copied as one.
            caption_path = next((c for c in candidates if os.path.isfile(c)), None)
            yield image_path, caption_path
2023-02-06 23:10:34 -07:00
def copy_captioned_image(image_caption_pair: tuple[str, Optional[str]], source_root: str, target_root: str):
    """Copy an image, and its caption if it has one, from one tree into another.

    The image's folder relative to *source_root* is recreated under
    *target_root* (created on demand), and both files keep their original
    base names. Copies are made with ``shutil.copy2``.

    Args:
        image_caption_pair: ``(image_path, caption_path)``; the caption path
            may be ``None``, in which case only the image is copied.
        source_root: Root that the image path is made relative to.
        target_root: Root of the destination tree.
    """
    image_path, caption_path = image_caption_pair[0], image_caption_pair[1]

    # Mirror the image's sub-folder inside the target tree.
    image_relative_path = os.path.relpath(image_path, source_root)
    destination_dir = os.path.join(target_root, os.path.dirname(image_relative_path))
    os.makedirs(destination_dir, exist_ok=True)

    # Image first, then the caption (skipped when there is none).
    for source_file in (image_path, caption_path):
        if source_file is None:
            continue
        destination_file = os.path.join(destination_dir, os.path.basename(source_file))
        shutil.copy2(source_file, destination_file)
if __name__ == '__main__':
    # Command-line entry point: gather every image (with its optional caption)
    # under source_root, shuffle them with a fixed seed, and copy the first
    # ceil(count * split_proportion) pairs to the 'val' folder. The remainder
    # is copied to the 'train' folder only when one was requested.
    parser = argparse.ArgumentParser()
    parser.add_argument('source_root', type=str, help='Source root folder containing images')
    parser.add_argument('--train_output_folder', type=str, required=False,
                        help="Output folder for the 'train' dataset. If omitted, do not save the train split.")
    parser.add_argument('--val_output_folder', type=str, required=True,
                        help="Output folder for the 'val' dataset.")
    parser.add_argument('--split_proportion', type=float, required=True,
                        help="Proportion of images to use for 'val' (a number between 0 and 1)")
    parser.add_argument('--seed', type=int, required=False, default=555, help='Random seed for shuffling')
    args = parser.parse_args()

    # Reject proportions outside [0, 1] up front: a negative value would turn
    # into a negative slice bound below and yield a nonsensical split, and a
    # value above 1 would silently put every image in 'val'. (NaN also fails
    # this comparison.) Zero is let through so it reaches the explicit
    # "No images in validation split" error below.
    if not 0 <= args.split_proportion <= 1:
        parser.error(f"--split_proportion must be between 0 and 1, got {args.split_proportion}")

    images = list(gather_captioned_images(args.source_root))
    print(f"Found {len(images)} captioned images in {args.source_root}")

    # Round up so any non-zero proportion of a non-empty set keeps at least one image.
    val_split_count = math.ceil(len(images) * args.split_proportion)
    if val_split_count == 0:
        raise ValueError(f"No images in validation split with source count {len(images)} and split proportion {args.split_proportion}")

    # Seed before shuffling so the same seed is applied on every run.
    random.seed(args.seed)
    random.shuffle(images)
    val_split = images[0:val_split_count]
    train_split = images[val_split_count:]
    print(f"Split to 'train' set with {len(train_split)} images and 'val' set with {len(val_split)}")

    print(f"Copying 'val' set to {args.val_output_folder}...")
    for v in tqdm(val_split):
        copy_captioned_image(v, args.source_root, args.val_output_folder)

    # The train split is optional; skip the copy when no folder was given.
    if args.train_output_folder is not None:
        print(f"Copying 'train' set to {args.train_output_folder}...")
        for v in tqdm(train_split):
            copy_captioned_image(v, args.source_root, args.train_output_folder)

    print("Done.")