make autocaption use async for all file operations, append _n to filename rename
This commit is contained in:
parent
fa0287be84
commit
1d17d8e28a
|
@ -11,6 +11,7 @@ import asyncio
|
|||
import subprocess
|
||||
import numpy as np
|
||||
import io
|
||||
import aiofiles
|
||||
|
||||
SIZE = 384
|
||||
BLIP_MODEL_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
||||
|
@ -39,7 +40,7 @@ def get_parser(**parser_kwargs):
|
|||
nargs="?",
|
||||
const=True,
|
||||
default="filename",
|
||||
help="'filename', 'json', or 'parquet'",
|
||||
help="'filename', 'mrwho', 'txt', or 'caption'",
|
||||
),
|
||||
parser.add_argument(
|
||||
"--nucleus",
|
||||
|
@ -115,7 +116,7 @@ async def main(opt):
|
|||
else:
|
||||
print(f"Model already cached to: {model_cache_path}")
|
||||
|
||||
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=384, vit='base', med_config=config_path)
|
||||
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path)
|
||||
blip_decoder.eval()
|
||||
|
||||
print("loading model to cuda")
|
||||
|
@ -131,13 +132,13 @@ async def main(opt):
|
|||
caption = None
|
||||
file_ext = os.path.splitext(img_file_name)[1]
|
||||
if (file_ext in ext):
|
||||
with open(img_file_name, "rb") as input_file:
|
||||
async with aiofiles.open(img_file_name, "rb") as input_file:
|
||||
print("working image: ", img_file_name)
|
||||
|
||||
image = Image.open(input_file)
|
||||
image_bin = await input_file.read()
|
||||
image = Image.open(io.BytesIO(image_bin))
|
||||
|
||||
if not image.mode == "RGB":
|
||||
print("converting to RGB")
|
||||
image = image.convert("RGB")
|
||||
|
||||
image = load_image(image, device=torch.device("cuda"))
|
||||
|
@ -150,14 +151,14 @@ async def main(opt):
|
|||
|
||||
caption = captions[0]
|
||||
|
||||
input_file.seek(0)
|
||||
data = input_file.read()
|
||||
input_file.close()
|
||||
|
||||
if opt.format in ["mrwho","joepenna"]:
|
||||
prefix = f"{i:05}@"
|
||||
i += 1
|
||||
caption = prefix+caption
|
||||
elif opt.format == "filename":
|
||||
postfix = f"_{i}"
|
||||
i += 1
|
||||
caption = caption+postfix
|
||||
|
||||
if opt.format in ["txt","text","caption"]:
|
||||
out_base_name = os.path.splitext(os.path.basename(img_file_name))[0]
|
||||
|
@ -170,14 +171,14 @@ async def main(opt):
|
|||
|
||||
if opt.format in ["txt","text","caption"]:
|
||||
print("writing caption to: ", out_file)
|
||||
with open(out_file, "w") as out_file:
|
||||
out_file.write(caption)
|
||||
async with aiofiles.open(out_file, "w") as out_file:
|
||||
await out_file.write(caption)
|
||||
|
||||
if opt.format in ["filename", "mrwho", "joepenna"]:
|
||||
caption = caption.replace("/", "").replace("\\", "") # must clean slashes using filename
|
||||
out_file = get_out_file_name(opt.out_dir, caption, file_ext)
|
||||
with open(out_file, "wb") as out_file:
|
||||
out_file.write(data)
|
||||
async with aiofiles.open(out_file, "wb") as out_file:
|
||||
await out_file.write(image_bin)
|
||||
elif opt.format == "json":
|
||||
raise NotImplementedError
|
||||
elif opt.format == "parquet":
|
||||
|
|
Loading…
Reference in New Issue