working on colab for auto_caption

Victor Hall 2022-10-31 00:02:10 -04:00
parent 54bb8cf87c
commit 53a774e696
3 changed files with 12 additions and 13 deletions

File diff suppressed because one or more lines are too long

Binary image file not shown (after: 9.3 KiB).

@@ -60,7 +60,7 @@ def get_parser(**parser_kwargs):
         const=True,
         default=24,
         help="adjusts the likelihood of a word being repeated",
-    )
+    ),
     return parser
@@ -82,27 +82,26 @@ async def main(opt):
     if opt.nucleus:
         sample = True
-    input_dir = os.path.join(os.getcwd(), opt.img_dir)
+    input_dir = opt.img_dir
     print("input_dir: ", input_dir)
-    config_path = os.path.join(os.getcwd(), "scripts/BLIP/configs/med_config.json")
+    config_path = "scripts/BLIP/configs/med_config.json"
     model_cache_path = ".cache/model_base_caption_capfilt_large.pth"
-    model_path = os.path.join(os.getcwd(), model_cache_path)
-    if not os.path.exists(model_path):
-        print(f"Downloading model to {model_path}... please wait")
+    if not os.path.exists(model_cache_path):
+        print(f"Downloading model to {model_cache_path}... please wait")
         blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
         async with aiohttp.ClientSession() as session:
             async with session.get(blip_model_url) as res:
                 result = await res.read()
-                with open(model_path, 'wb') as f:
+                with open(model_cache_path, 'wb') as f:
                     f.write(result)
-                print(f"Model cached to: {model_path}")
+                print(f"Model cached to: {model_cache_path}")
     else:
-        print(f"Model already cached to: {model_path}")
-    blip_decoder = models.blip.blip_decoder(pretrained=model_path, image_size=384, vit='base', med_config=config_path)
+        print(f"Model already cached to: {model_cache_path}")
+    blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=384, vit='base', med_config=config_path)
     blip_decoder.eval()
     print("loading model to cuda")
@@ -159,7 +158,7 @@ def isWindows():
     return sys.platform.startswith("win")
 if __name__ == "__main__":
-    print("starting")
+    print(f"starting in {print(os.getcwd())}")
     parser = get_parser()
     opt = parser.parse_args()
@@ -172,7 +171,7 @@ if __name__ == "__main__":
     else:
         print("Unix detected, using default asyncio event loop policy")
-    blip_path = os.path.join(os.getcwd(), "scripts/BLIP")
+    blip_path = "scripts/BLIP"
     sys.path.append(blip_path)
     asyncio.run(main(opt))
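
Because the commit replaces os.getcwd()-based absolute paths with paths relative to the working directory ("scripts/BLIP", ".cache/..."), the script now depends on being launched from the repository root. A hedged sketch of the kind of Colab setup cell this implies; the clone location /content/EveryDream is an assumption for illustration, not taken from the commit.

# Assumed Colab setup cell; the clone path below is hypothetical.
import os
import sys

os.chdir("/content/EveryDream")   # assumption: repository cloned here in the Colab notebook
sys.path.append("scripts/BLIP")   # matches the relative path the commit switches to
print("working directory:", os.getcwd())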