diff --git a/test/test_dl.py b/test/test_dl.py index d0de4f0..527030d 100644 --- a/test/test_dl.py +++ b/test/test_dl.py @@ -7,7 +7,7 @@ import argparse import ldm.data.data_loader as dl def main(data_root, batch_size): - data_loader = dl.DataLoaderMultiAspect(data_root=data_root, batch_size=batch_size, debug_level=1) + data_loader = dl.DataLoaderMultiAspect(data_root=data_root, batch_size=batch_size, debug_level=1, resolution=512) image_caption_pairs = data_loader.get_all_images() @@ -25,5 +25,6 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--data_root", type=str, default="input", help="root folder of training data") parser.add_argument("--batch_size", type=int, default=4, help="number of images per batch") + parser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=[512, 576, 640, 704, 768]) args = parser.parse_args() - main(data_root=args.data_root, batch_size=args.batch_size) \ No newline at end of file + main(data_root=args.data_root, batch_size=args.batch_size, resolution=args.resolution) \ No newline at end of file