make autocaption use async for all file operations, append _n to filename rename

This commit is contained in:
Victor Hall 2022-11-15 19:07:28 -05:00
parent fa0287be84
commit 1d17d8e28a
1 changed files with 14 additions and 13 deletions

View File

@ -11,6 +11,7 @@ import asyncio
import subprocess
import numpy as np
import io
import aiofiles
SIZE = 384
BLIP_MODEL_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
@ -39,7 +40,7 @@ def get_parser(**parser_kwargs):
nargs="?",
const=True,
default="filename",
help="'filename', 'json', or 'parquet'",
help="'filename', 'mrwho', 'txt', or 'caption'",
),
parser.add_argument(
"--nucleus",
@ -115,7 +116,7 @@ async def main(opt):
else:
print(f"Model already cached to: {model_cache_path}")
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=384, vit='base', med_config=config_path)
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path)
blip_decoder.eval()
print("loading model to cuda")
@ -131,13 +132,13 @@ async def main(opt):
caption = None
file_ext = os.path.splitext(img_file_name)[1]
if (file_ext in ext):
with open(img_file_name, "rb") as input_file:
async with aiofiles.open(img_file_name, "rb") as input_file:
print("working image: ", img_file_name)
image = Image.open(input_file)
image_bin = await input_file.read()
image = Image.open(io.BytesIO(image_bin))
if not image.mode == "RGB":
print("converting to RGB")
image = image.convert("RGB")
image = load_image(image, device=torch.device("cuda"))
@ -150,14 +151,14 @@ async def main(opt):
caption = captions[0]
input_file.seek(0)
data = input_file.read()
input_file.close()
if opt.format in ["mrwho","joepenna"]:
prefix = f"{i:05}@"
i += 1
caption = prefix+caption
elif opt.format == "filename":
postfix = f"_{i}"
i += 1
caption = caption+postfix
if opt.format in ["txt","text","caption"]:
out_base_name = os.path.splitext(os.path.basename(img_file_name))[0]
@ -170,14 +171,14 @@ async def main(opt):
if opt.format in ["txt","text","caption"]:
print("writing caption to: ", out_file)
with open(out_file, "w") as out_file:
out_file.write(caption)
async with aiofiles.open(out_file, "w") as out_file:
await out_file.write(caption)
if opt.format in ["filename", "mrwho", "joepenna"]:
caption = caption.replace("/", "").replace("\\", "") # must clean slashes using filename
out_file = get_out_file_name(opt.out_dir, caption, file_ext)
with open(out_file, "wb") as out_file:
out_file.write(data)
async with aiofiles.open(out_file, "wb") as out_file:
await out_file.write(image_bin)
elif opt.format == "json":
raise NotImplementedError
elif opt.format == "parquet":