Make autocaption use async I/O for all file operations; append _n to the caption-based output name when renaming in "filename" format
This commit is contained in:
parent
fa0287be84
commit
1d17d8e28a
|
@ -11,6 +11,7 @@ import asyncio
|
||||||
import subprocess
|
import subprocess
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import io
|
import io
|
||||||
|
import aiofiles
|
||||||
|
|
||||||
SIZE = 384
|
SIZE = 384
|
||||||
BLIP_MODEL_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
BLIP_MODEL_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
||||||
|
@ -39,7 +40,7 @@ def get_parser(**parser_kwargs):
|
||||||
nargs="?",
|
nargs="?",
|
||||||
const=True,
|
const=True,
|
||||||
default="filename",
|
default="filename",
|
||||||
help="'filename', 'json', or 'parquet'",
|
help="'filename', 'mrwho', 'txt', or 'caption'",
|
||||||
),
|
),
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--nucleus",
|
"--nucleus",
|
||||||
|
@ -115,7 +116,7 @@ async def main(opt):
|
||||||
else:
|
else:
|
||||||
print(f"Model already cached to: {model_cache_path}")
|
print(f"Model already cached to: {model_cache_path}")
|
||||||
|
|
||||||
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=384, vit='base', med_config=config_path)
|
blip_decoder = models.blip.blip_decoder(pretrained=model_cache_path, image_size=SIZE, vit='base', med_config=config_path)
|
||||||
blip_decoder.eval()
|
blip_decoder.eval()
|
||||||
|
|
||||||
print("loading model to cuda")
|
print("loading model to cuda")
|
||||||
|
@ -131,13 +132,13 @@ async def main(opt):
|
||||||
caption = None
|
caption = None
|
||||||
file_ext = os.path.splitext(img_file_name)[1]
|
file_ext = os.path.splitext(img_file_name)[1]
|
||||||
if (file_ext in ext):
|
if (file_ext in ext):
|
||||||
with open(img_file_name, "rb") as input_file:
|
async with aiofiles.open(img_file_name, "rb") as input_file:
|
||||||
print("working image: ", img_file_name)
|
print("working image: ", img_file_name)
|
||||||
|
|
||||||
image = Image.open(input_file)
|
image_bin = await input_file.read()
|
||||||
|
image = Image.open(io.BytesIO(image_bin))
|
||||||
|
|
||||||
if not image.mode == "RGB":
|
if not image.mode == "RGB":
|
||||||
print("converting to RGB")
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
|
||||||
image = load_image(image, device=torch.device("cuda"))
|
image = load_image(image, device=torch.device("cuda"))
|
||||||
|
@ -150,14 +151,14 @@ async def main(opt):
|
||||||
|
|
||||||
caption = captions[0]
|
caption = captions[0]
|
||||||
|
|
||||||
input_file.seek(0)
|
|
||||||
data = input_file.read()
|
|
||||||
input_file.close()
|
|
||||||
|
|
||||||
if opt.format in ["mrwho","joepenna"]:
|
if opt.format in ["mrwho","joepenna"]:
|
||||||
prefix = f"{i:05}@"
|
prefix = f"{i:05}@"
|
||||||
i += 1
|
i += 1
|
||||||
caption = prefix+caption
|
caption = prefix+caption
|
||||||
|
elif opt.format == "filename":
|
||||||
|
postfix = f"_{i}"
|
||||||
|
i += 1
|
||||||
|
caption = caption+postfix
|
||||||
|
|
||||||
if opt.format in ["txt","text","caption"]:
|
if opt.format in ["txt","text","caption"]:
|
||||||
out_base_name = os.path.splitext(os.path.basename(img_file_name))[0]
|
out_base_name = os.path.splitext(os.path.basename(img_file_name))[0]
|
||||||
|
@ -170,14 +171,14 @@ async def main(opt):
|
||||||
|
|
||||||
if opt.format in ["txt","text","caption"]:
|
if opt.format in ["txt","text","caption"]:
|
||||||
print("writing caption to: ", out_file)
|
print("writing caption to: ", out_file)
|
||||||
with open(out_file, "w") as out_file:
|
async with aiofiles.open(out_file, "w") as out_file:
|
||||||
out_file.write(caption)
|
await out_file.write(caption)
|
||||||
|
|
||||||
if opt.format in ["filename", "mrwho", "joepenna"]:
|
if opt.format in ["filename", "mrwho", "joepenna"]:
|
||||||
caption = caption.replace("/", "").replace("\\", "") # must clean slashes using filename
|
caption = caption.replace("/", "").replace("\\", "") # must clean slashes using filename
|
||||||
out_file = get_out_file_name(opt.out_dir, caption, file_ext)
|
out_file = get_out_file_name(opt.out_dir, caption, file_ext)
|
||||||
with open(out_file, "wb") as out_file:
|
async with aiofiles.open(out_file, "wb") as out_file:
|
||||||
out_file.write(data)
|
await out_file.write(image_bin)
|
||||||
elif opt.format == "json":
|
elif opt.format == "json":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
elif opt.format == "parquet":
|
elif opt.format == "parquet":
|
||||||
|
|
Loading…
Reference in New Issue