make autocaption use async for all file operations, append _n to filename rename

This commit is contained in:
Victor Hall 2022-11-15 19:07:28 -05:00
parent fa0287be84
commit 1d17d8e28a
1 changed files with 14 additions and 13 deletions

View File

@ -11,6 +11,7 @@ import asyncio
import subprocess import subprocess
import numpy as np import numpy as np
import io import io
import aiofiles
SIZE = 384 SIZE = 384
BLIP_MODEL_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' BLIP_MODEL_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
@ -39,7 +40,7 @@ def get_parser(**parser_kwargs):
nargs="?", nargs="?",
const=True, const=True,
default="filename", default="filename",
help="'filename', 'json', or 'parquet'", help="'filename', 'mrwho', 'txt', or 'caption'",
), ),
parser.add_argument( parser.add_argument(
"--nucleus", "--nucleus",
@ -115,7 +116,7 @@ async def main(opt):
else: else:
print(f"Model already cached to: {model_cache_path}") print(f"Model already cached to: {model_cache_path}")
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=384, vit='base', med_config=config_path) blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path)
blip_decoder.eval() blip_decoder.eval()
print("loading model to cuda") print("loading model to cuda")
@ -131,13 +132,13 @@ async def main(opt):
caption = None caption = None
file_ext = os.path.splitext(img_file_name)[1] file_ext = os.path.splitext(img_file_name)[1]
if (file_ext in ext): if (file_ext in ext):
with open(img_file_name, "rb") as input_file: async with aiofiles.open(img_file_name, "rb") as input_file:
print("working image: ", img_file_name) print("working image: ", img_file_name)
image = Image.open(input_file) image_bin = await input_file.read()
image = Image.open(io.BytesIO(image_bin))
if not image.mode == "RGB": if not image.mode == "RGB":
print("converting to RGB")
image = image.convert("RGB") image = image.convert("RGB")
image = load_image(image, device=torch.device("cuda")) image = load_image(image, device=torch.device("cuda"))
@ -150,14 +151,14 @@ async def main(opt):
caption = captions[0] caption = captions[0]
input_file.seek(0)
data = input_file.read()
input_file.close()
if opt.format in ["mrwho","joepenna"]: if opt.format in ["mrwho","joepenna"]:
prefix = f"{i:05}@" prefix = f"{i:05}@"
i += 1 i += 1
caption = prefix+caption caption = prefix+caption
elif opt.format == "filename":
postfix = f"_{i}"
i += 1
caption = caption+postfix
if opt.format in ["txt","text","caption"]: if opt.format in ["txt","text","caption"]:
out_base_name = os.path.splitext(os.path.basename(img_file_name))[0] out_base_name = os.path.splitext(os.path.basename(img_file_name))[0]
@ -170,14 +171,14 @@ async def main(opt):
if opt.format in ["txt","text","caption"]: if opt.format in ["txt","text","caption"]:
print("writing caption to: ", out_file) print("writing caption to: ", out_file)
with open(out_file, "w") as out_file: async with aiofiles.open(out_file, "w") as out_file:
out_file.write(caption) await out_file.write(caption)
if opt.format in ["filename", "mrwho", "joepenna"]: if opt.format in ["filename", "mrwho", "joepenna"]:
caption = caption.replace("/", "").replace("\\", "") # must clean slashes using filename caption = caption.replace("/", "").replace("\\", "") # must clean slashes using filename
out_file = get_out_file_name(opt.out_dir, caption, file_ext) out_file = get_out_file_name(opt.out_dir, caption, file_ext)
with open(out_file, "wb") as out_file: async with aiofiles.open(out_file, "wb") as out_file:
out_file.write(data) await out_file.write(image_bin)
elif opt.format == "json": elif opt.format == "json":
raise NotImplementedError raise NotImplementedError
elif opt.format == "parquet": elif opt.format == "parquet":