working on colab for auto_caption
This commit is contained in:
parent
54bb8cf87c
commit
53a774e696
File diff suppressed because one or more lines are too long
Binary file not shown.
After Width: | Height: | Size: 9.3 KiB |
|
@ -60,7 +60,7 @@ def get_parser(**parser_kwargs):
|
|||
const=True,
|
||||
default=24,
|
||||
help="adjusts the likelihood of a word being repeated",
|
||||
)
|
||||
),
|
||||
|
||||
return parser
|
||||
|
||||
|
@ -82,27 +82,26 @@ async def main(opt):
|
|||
if opt.nucleus:
|
||||
sample = True
|
||||
|
||||
input_dir = os.path.join(os.getcwd(), opt.img_dir)
|
||||
input_dir = opt.img_dir
|
||||
print("input_dir: ", input_dir)
|
||||
|
||||
config_path = os.path.join(os.getcwd(), "scripts/BLIP/configs/med_config.json")
|
||||
config_path = "scripts/BLIP/configs/med_config.json"
|
||||
|
||||
model_cache_path = ".cache/model_base_caption_capfilt_large.pth"
|
||||
model_path = os.path.join(os.getcwd(), model_cache_path)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Downloading model to {model_path}... please wait")
|
||||
if not os.path.exists(model_cache_path):
|
||||
print(f"Downloading model to {model_cache_path}... please wait")
|
||||
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(blip_model_url) as res:
|
||||
result = await res.read()
|
||||
with open(model_path, 'wb') as f:
|
||||
with open(model_cache_path, 'wb') as f:
|
||||
f.write(result)
|
||||
print(f"Model cached to: {model_path}")
|
||||
print(f"Model cached to: {model_cache_path}")
|
||||
else:
|
||||
print(f"Model already cached to: {model_path}")
|
||||
print(f"Model already cached to: {model_cache_path}")
|
||||
|
||||
blip_decoder = models.blip.blip_decoder(pretrained=model_path, image_size=384, vit='base', med_config=config_path)
|
||||
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=384, vit='base', med_config=config_path)
|
||||
blip_decoder.eval()
|
||||
|
||||
print("loading model to cuda")
|
||||
|
@ -159,7 +158,7 @@ def isWindows():
|
|||
return sys.platform.startswith("win")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("starting")
|
||||
print(f"starting in {print(os.getcwd())}")
|
||||
parser = get_parser()
|
||||
opt = parser.parse_args()
|
||||
|
||||
|
@ -172,7 +171,7 @@ if __name__ == "__main__":
|
|||
else:
|
||||
print("Unix detected, using default asyncio event loop policy")
|
||||
|
||||
blip_path = os.path.join(os.getcwd(), "scripts/BLIP")
|
||||
blip_path = "scripts/BLIP"
|
||||
sys.path.append(blip_path)
|
||||
|
||||
asyncio.run(main(opt))
|
||||
|
|
Loading…
Reference in New Issue