Add torch_device option to scripts/auto_caption.py

This allows using auto_caption on apple sillicon macs by specifying cpu
as an argument for now, and might allow using mps eventually, once more
operators are implemented.
This commit is contained in:
Victor Hall 2023-01-09 22:46:40 -05:00 committed by Devin Lehmacher
parent fd67f7ad7f
commit da15f0a745
1 changed files with 15 additions and 8 deletions

View File

@ -66,6 +66,14 @@ def get_parser(**parser_kwargs):
default=22, default=22,
help="adjusts the likelihood of a word being repeated", help="adjusts the likelihood of a word being repeated",
), ),
parser.add_argument(
"--torch_device",
type=str,
nargs="?",
const=False,
default="cuda",
help="specify a different torch device, e.g. 'cpu'",
),
return parser return parser
@ -119,9 +127,9 @@ async def main(opt):
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path) blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path)
blip_decoder.eval() blip_decoder.eval()
print("loading model to cuda") print(f"loading model to {opt.torch_device}")
blip_decoder = blip_decoder.to(torch.device("cuda")) blip_decoder = blip_decoder.to(torch.device(opt.torch_device))
ext = ('.jpg', '.jpeg', '.png', '.webp', '.tif', '.tga', '.tiff', '.bmp', '.gif') ext = ('.jpg', '.jpeg', '.png', '.webp', '.tif', '.tga', '.tiff', '.bmp', '.gif')
@ -141,7 +149,7 @@ async def main(opt):
if not image.mode == "RGB": if not image.mode == "RGB":
image = image.convert("RGB") image = image.convert("RGB")
image = load_image(image, device=torch.device("cuda")) image = load_image(image, device=torch.device(opt.torch_device))
if opt.nucleus: if opt.nucleus:
captions = blip_decoder.generate(image, sample=True, top_p=opt.q_factor) captions = blip_decoder.generate(image, sample=True, top_p=opt.q_factor)
@ -207,4 +215,3 @@ if __name__ == "__main__":
sys.path.append(blip_path) sys.path.append(blip_path)
asyncio.run(main(opt)) asyncio.run(main(opt))