working on colab for auto_caption
This commit is contained in:
parent
54bb8cf87c
commit
53a774e696
File diff suppressed because one or more lines are too long
Binary file not shown.
After Width: | Height: | Size: 9.3 KiB |
|
@ -60,7 +60,7 @@ def get_parser(**parser_kwargs):
|
||||||
const=True,
|
const=True,
|
||||||
default=24,
|
default=24,
|
||||||
help="adjusts the likelihood of a word being repeated",
|
help="adjusts the likelihood of a word being repeated",
|
||||||
)
|
),
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
@ -82,27 +82,26 @@ async def main(opt):
|
||||||
if opt.nucleus:
|
if opt.nucleus:
|
||||||
sample = True
|
sample = True
|
||||||
|
|
||||||
input_dir = os.path.join(os.getcwd(), opt.img_dir)
|
input_dir = opt.img_dir
|
||||||
print("input_dir: ", input_dir)
|
print("input_dir: ", input_dir)
|
||||||
|
|
||||||
config_path = os.path.join(os.getcwd(), "scripts/BLIP/configs/med_config.json")
|
config_path = "scripts/BLIP/configs/med_config.json"
|
||||||
|
|
||||||
model_cache_path = ".cache/model_base_caption_capfilt_large.pth"
|
model_cache_path = ".cache/model_base_caption_capfilt_large.pth"
|
||||||
model_path = os.path.join(os.getcwd(), model_cache_path)
|
|
||||||
|
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_cache_path):
|
||||||
print(f"Downloading model to {model_path}... please wait")
|
print(f"Downloading model to {model_cache_path}... please wait")
|
||||||
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(blip_model_url) as res:
|
async with session.get(blip_model_url) as res:
|
||||||
result = await res.read()
|
result = await res.read()
|
||||||
with open(model_path, 'wb') as f:
|
with open(model_cache_path, 'wb') as f:
|
||||||
f.write(result)
|
f.write(result)
|
||||||
print(f"Model cached to: {model_path}")
|
print(f"Model cached to: {model_cache_path}")
|
||||||
else:
|
else:
|
||||||
print(f"Model already cached to: {model_path}")
|
print(f"Model already cached to: {model_cache_path}")
|
||||||
|
|
||||||
blip_decoder = models.blip.blip_decoder(pretrained=model_path, image_size=384, vit='base', med_config=config_path)
|
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=384, vit='base', med_config=config_path)
|
||||||
blip_decoder.eval()
|
blip_decoder.eval()
|
||||||
|
|
||||||
print("loading model to cuda")
|
print("loading model to cuda")
|
||||||
|
@ -159,7 +158,7 @@ def isWindows():
|
||||||
return sys.platform.startswith("win")
|
return sys.platform.startswith("win")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("starting")
|
print(f"starting in {print(os.getcwd())}")
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
@ -172,7 +171,7 @@ if __name__ == "__main__":
|
||||||
else:
|
else:
|
||||||
print("Unix detected, using default asyncio event loop policy")
|
print("Unix detected, using default asyncio event loop policy")
|
||||||
|
|
||||||
blip_path = os.path.join(os.getcwd(), "scripts/BLIP")
|
blip_path = "scripts/BLIP"
|
||||||
sys.path.append(blip_path)
|
sys.path.append(blip_path)
|
||||||
|
|
||||||
asyncio.run(main(opt))
|
asyncio.run(main(opt))
|
||||||
|
|
Loading…
Reference in New Issue