add res to dataloader test script
This commit is contained in:
parent
8211d11329
commit
6a22fa7594
|
@ -7,7 +7,7 @@ import argparse
|
||||||
import ldm.data.data_loader as dl
|
import ldm.data.data_loader as dl
|
||||||
|
|
||||||
def main(data_root, batch_size):
|
def main(data_root, batch_size):
|
||||||
data_loader = dl.DataLoaderMultiAspect(data_root=data_root, batch_size=batch_size, debug_level=1)
|
data_loader = dl.DataLoaderMultiAspect(data_root=data_root, batch_size=batch_size, debug_level=1, resolution=512)
|
||||||
|
|
||||||
image_caption_pairs = data_loader.get_all_images()
|
image_caption_pairs = data_loader.get_all_images()
|
||||||
|
|
||||||
|
@ -25,5 +25,6 @@ if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--data_root", type=str, default="input", help="root folder of training data")
|
parser.add_argument("--data_root", type=str, default="input", help="root folder of training data")
|
||||||
parser.add_argument("--batch_size", type=int, default=4, help="number of images per batch")
|
parser.add_argument("--batch_size", type=int, default=4, help="number of images per batch")
|
||||||
|
parser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=[512, 576, 640, 704, 768])
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(data_root=args.data_root, batch_size=args.batch_size)
|
main(data_root=args.data_root, batch_size=args.batch_size, resolution=args.resolution)
|
Loading…
Reference in New Issue