diff --git a/scripts/auto_caption.py b/scripts/auto_caption.py index 015e99e..b5c2f8e 100644 --- a/scripts/auto_caption.py +++ b/scripts/auto_caption.py @@ -66,6 +66,14 @@ def get_parser(**parser_kwargs): default=22, help="adjusts the likelihood of a word being repeated", ), + parser.add_argument( + "--torch_device", + type=str, + nargs="?", + const=False, + default="cuda", + help="specify a different torch device, e.g. 'cpu'", + ), return parser @@ -100,13 +108,13 @@ async def main(opt): if not os.path.exists(cache_folder): os.makedirs(cache_folder) - + if not os.path.exists(opt.out_dir): os.makedirs(opt.out_dir) if not os.path.exists(model_cache_path): print(f"Downloading model to {model_cache_path}... please wait") - + async with aiohttp.ClientSession() as session: async with session.get(BLIP_MODEL_URL) as res: with open(model_cache_path, 'wb') as f: @@ -119,9 +127,9 @@ async def main(opt): blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path) blip_decoder.eval() - print("loading model to cuda") + print(f"loading model to {opt.torch_device}") - blip_decoder = blip_decoder.to(torch.device("cuda")) + blip_decoder = blip_decoder.to(torch.device(opt.torch_device)) ext = ('.jpg', '.jpeg', '.png', '.webp', '.tif', '.tga', '.tiff', '.bmp', '.gif') @@ -141,7 +149,7 @@ async def main(opt): if not image.mode == "RGB": image = image.convert("RGB") - image = load_image(image, device=torch.device("cuda")) + image = load_image(image, device=torch.device(opt.torch_device)) if opt.nucleus: captions = blip_decoder.generate(image, sample=True, top_p=opt.q_factor) @@ -193,8 +201,8 @@ if __name__ == "__main__": if opt.format not in ["filename", "mrwho", "joepenna", "txt", "text", "caption"]: raise ValueError("format must be 'filename', 'mrwho', 'txt', or 'caption'") - - if (isWindows()): + + if (isWindows()): print("Windows detected, using asyncio.WindowsSelectorEventLoopPolicy") asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) else: @@ -207,4 +215,3 @@ if __name__ == "__main__": sys.path.append(blip_path) asyncio.run(main(opt)) - \ No newline at end of file