Add torch_device option to scripts/auto_caption.py
This allows using auto_caption on Apple silicon Macs by specifying 'cpu' as the torch device for now, and might allow using 'mps' eventually, once more operators are implemented for that backend.
This commit is contained in:
parent
fd67f7ad7f
commit
da15f0a745
|
@ -66,6 +66,14 @@ def get_parser(**parser_kwargs):
|
||||||
default=22,
|
default=22,
|
||||||
help="adjusts the likelihood of a word being repeated",
|
help="adjusts the likelihood of a word being repeated",
|
||||||
),
|
),
|
||||||
|
parser.add_argument(
|
||||||
|
"--torch_device",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
const=False,
|
||||||
|
default="cuda",
|
||||||
|
help="specify a different torch device, e.g. 'cpu'",
|
||||||
|
),
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
@ -100,13 +108,13 @@ async def main(opt):
|
||||||
|
|
||||||
if not os.path.exists(cache_folder):
|
if not os.path.exists(cache_folder):
|
||||||
os.makedirs(cache_folder)
|
os.makedirs(cache_folder)
|
||||||
|
|
||||||
if not os.path.exists(opt.out_dir):
|
if not os.path.exists(opt.out_dir):
|
||||||
os.makedirs(opt.out_dir)
|
os.makedirs(opt.out_dir)
|
||||||
|
|
||||||
if not os.path.exists(model_cache_path):
|
if not os.path.exists(model_cache_path):
|
||||||
print(f"Downloading model to {model_cache_path}... please wait")
|
print(f"Downloading model to {model_cache_path}... please wait")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(BLIP_MODEL_URL) as res:
|
async with session.get(BLIP_MODEL_URL) as res:
|
||||||
with open(model_cache_path, 'wb') as f:
|
with open(model_cache_path, 'wb') as f:
|
||||||
|
@ -119,9 +127,9 @@ async def main(opt):
|
||||||
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path)
|
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path)
|
||||||
blip_decoder.eval()
|
blip_decoder.eval()
|
||||||
|
|
||||||
print("loading model to cuda")
|
print(f"loading model to {opt.torch_device}")
|
||||||
|
|
||||||
blip_decoder = blip_decoder.to(torch.device("cuda"))
|
blip_decoder = blip_decoder.to(torch.device(opt.torch_device))
|
||||||
|
|
||||||
ext = ('.jpg', '.jpeg', '.png', '.webp', '.tif', '.tga', '.tiff', '.bmp', '.gif')
|
ext = ('.jpg', '.jpeg', '.png', '.webp', '.tif', '.tga', '.tiff', '.bmp', '.gif')
|
||||||
|
|
||||||
|
@ -141,7 +149,7 @@ async def main(opt):
|
||||||
if not image.mode == "RGB":
|
if not image.mode == "RGB":
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
|
||||||
image = load_image(image, device=torch.device("cuda"))
|
image = load_image(image, device=torch.device(opt.torch_device))
|
||||||
|
|
||||||
if opt.nucleus:
|
if opt.nucleus:
|
||||||
captions = blip_decoder.generate(image, sample=True, top_p=opt.q_factor)
|
captions = blip_decoder.generate(image, sample=True, top_p=opt.q_factor)
|
||||||
|
@ -193,8 +201,8 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
if opt.format not in ["filename", "mrwho", "joepenna", "txt", "text", "caption"]:
|
if opt.format not in ["filename", "mrwho", "joepenna", "txt", "text", "caption"]:
|
||||||
raise ValueError("format must be 'filename', 'mrwho', 'txt', or 'caption'")
|
raise ValueError("format must be 'filename', 'mrwho', 'txt', or 'caption'")
|
||||||
|
|
||||||
if (isWindows()):
|
if (isWindows()):
|
||||||
print("Windows detected, using asyncio.WindowsSelectorEventLoopPolicy")
|
print("Windows detected, using asyncio.WindowsSelectorEventLoopPolicy")
|
||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||||
else:
|
else:
|
||||||
|
@ -207,4 +215,3 @@ if __name__ == "__main__":
|
||||||
sys.path.append(blip_path)
|
sys.path.append(blip_path)
|
||||||
|
|
||||||
asyncio.run(main(opt))
|
asyncio.run(main(opt))
|
||||||
|
|
Loading…
Reference in New Issue