diff --git a/scripts/auto_caption.py b/scripts/auto_caption.py index 7833aef..4b3daf5 100644 --- a/scripts/auto_caption.py +++ b/scripts/auto_caption.py @@ -11,6 +11,7 @@ import asyncio import subprocess import numpy as np import io +import aiofiles SIZE = 384 BLIP_MODEL_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' @@ -39,7 +40,7 @@ def get_parser(**parser_kwargs): nargs="?", const=True, default="filename", - help="'filename', 'json', or 'parquet'", + help="'filename', 'mrwho', 'txt', or 'caption'", ), parser.add_argument( "--nucleus", @@ -115,7 +116,7 @@ async def main(opt): else: print(f"Model already cached to: {model_cache_path}") - blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=384, vit='base', med_config=config_path) + blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path) blip_decoder.eval() print("loading model to cuda") @@ -131,13 +132,13 @@ async def main(opt): caption = None file_ext = os.path.splitext(img_file_name)[1] if (file_ext in ext): - with open(img_file_name, "rb") as input_file: + async with aiofiles.open(img_file_name, "rb") as input_file: print("working image: ", img_file_name) - image = Image.open(input_file) + image_bin = await input_file.read() + image = Image.open(io.BytesIO(image_bin)) if not image.mode == "RGB": - print("converting to RGB") image = image.convert("RGB") image = load_image(image, device=torch.device("cuda")) @@ -150,14 +151,14 @@ async def main(opt): caption = captions[0] - input_file.seek(0) - data = input_file.read() - input_file.close() - if opt.format in ["mrwho","joepenna"]: prefix = f"{i:05}@" i += 1 caption = prefix+caption + elif opt.format == "filename": + postfix = f"_{i}" + i += 1 + caption = caption+postfix if opt.format in ["txt","text","caption"]: out_base_name = os.path.splitext(os.path.basename(img_file_name))[0] @@ -170,14 +171,14 @@ async def main(opt): if opt.format in ["txt","text","caption"]: print("writing caption to: ", out_file) - with open(out_file, "w") as out_file: - out_file.write(caption) + async with aiofiles.open(out_file, "w") as out_file: + await out_file.write(caption) if opt.format in ["filename", "mrwho", "joepenna"]: caption = caption.replace("/", "").replace("\\", "") # must clean slashes using filename out_file = get_out_file_name(opt.out_dir, caption, file_ext) - with open(out_file, "wb") as out_file: - out_file.write(data) + async with aiofiles.open(out_file, "wb") as out_file: + await out_file.write(image_bin) elif opt.format == "json": raise NotImplementedError elif opt.format == "parquet":