179 lines
5.7 KiB
Python
179 lines
5.7 KiB
Python
|
import argparse
|
||
|
import glob
|
||
|
import os
|
||
|
from PIL import Image
|
||
|
import sys
|
||
|
from torchvision import transforms
|
||
|
from torchvision.transforms.functional import InterpolationMode
|
||
|
import torch
|
||
|
import aiohttp
|
||
|
import asyncio
|
||
|
|
||
|
# Side length, in pixels, of the square that load_image() resizes every input
# image to before it is fed to BLIP.
# NOTE(review): presumably must match the image_size the BLIP decoder is built
# with in main() — confirm before changing.
SIZE = 384
|
||
|
|
||
|
def _str2bool(value):
    """Parse a command-line boolean ("true"/"false", "yes"/"no", "1"/"0", ...).

    Used instead of ``type=bool``: argparse would call ``bool("False")``, which
    is True because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognized boolean word.
    """
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_parser(**parser_kwargs):
    """Build the argument parser for the BLIP auto-captioning script.

    Every option takes an optional value (``nargs="?"``). When an option is
    given bare, argparse uses its ``const``. For the value options that is the
    same as the default. Previously ``const=True`` turned, for example, a bare
    ``--img_dir`` into the boolean ``True``. A bare ``--nucleus`` still means
    True.

    Args:
        **parser_kwargs: forwarded unchanged to ``argparse.ArgumentParser``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "--img_dir",
        type=str,
        nargs="?",
        const="input",
        default="input",
        help="directory with images to be captioned",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        nargs="?",
        const="output",
        default="output",
        help="directory to put captioned images",
    )
    parser.add_argument(
        "--format",
        type=str,
        nargs="?",
        const="filename",
        default="filename",
        help="'filename' or 'mrwho' ('joepenna', 'json' and 'parquet' are not implemented)",
    )
    parser.add_argument(
        "--nucleus",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="use nucleus sampling instead of beam search (accepts true/false)",
    )
    parser.add_argument(
        "--q_factor",
        type=float,
        nargs="?",
        const=0.8,
        default=0.8,
        help="repetition penalty for beam search, or top_p for nucleus sampling",
    )
    parser.add_argument(
        "--min_length",
        type=int,
        nargs="?",
        const=24,
        default=24,
        help="minimum caption length (beam search only)",
    )
    return parser
|
||
|
|
||
|
def load_image(raw_image, device):
    """Turn a PIL image into a normalized, batched tensor for the BLIP decoder.

    The image is converted to RGB if needed and resized to ``SIZE`` x ``SIZE``
    with bicubic interpolation. The aspect ratio is not preserved. The result
    is turned into a tensor, normalized per channel, and given a leading batch
    dimension.

    Args:
        raw_image: a PIL image in any mode (RGB, RGBA, P, L, ...).
        device: the torch device the returned tensor is moved to.

    Returns:
        A tensor of shape ``(1, 3, SIZE, SIZE)`` on *device*.
    """
    # The Normalize below has exactly three channel statistics. RGBA PNGs,
    # palette GIFs and grayscale images would otherwise produce 4- or 1-channel
    # tensors and fail there.
    if raw_image.mode != "RGB":
        raw_image = raw_image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((SIZE, SIZE), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # NOTE(review): these are the ImageNet statistics. BLIP's own demo and
        # data pipeline appear to use the CLIP statistics
        # ((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)).
        # Confirm which the checkpoint expects before changing.
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    return transform(raw_image).unsqueeze(0).to(device)
|
||
|
|
||
|
_BLIP_MODEL_URL = (
    "https://storage.googleapis.com/sfr-vision-language-research/"
    "BLIP/models/model_base_caption_capfilt_large.pth"
)

# Extensions (compared lower-cased) of files treated as images.
_IMAGE_EXTS = frozenset(
    ('.jpg', '.jpeg', '.png', '.webp', '.tif', '.tga', '.tiff', '.bmp', '.gif')
)

# Characters that are invalid in Windows file names ("/" is also invalid on
# POSIX), plus ASCII control characters. Each is replaced by "_".
_UNSAFE_FILENAME_CHARS = str.maketrans(
    {ch: "_" for ch in '<>:"/\\|?*' + "".join(map(chr, range(32)))}
)


async def _download_model(url, model_path):
    """Stream *url* to *model_path*, creating the parent directory if needed.

    Data is written to ``model_path + ".part"`` and renamed into place only
    after the download completes. An interrupted download therefore never
    leaves a truncated file that a later run would treat as a cached model.
    """
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)
    tmp_path = model_path + ".part"
    # aiohttp's default ClientTimeout caps the *total* request time at
    # 5 minutes, which a large checkpoint on a slow link can exceed. Bound
    # connection setup and idle reads instead.
    timeout = aiohttp.ClientTimeout(total=None, sock_connect=60, sock_read=300)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url) as res:
                # Without this check, an HTTP error page would be saved as the model.
                res.raise_for_status()
                with open(tmp_path, "wb") as f:
                    async for chunk in res.content.iter_chunked(1 << 20):
                        f.write(chunk)
        os.replace(tmp_path, model_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def _list_images(img_dir):
    """Return files in *img_dir* that have a supported image extension, sorted.

    Extension matching is case-insensitive, so ``.JPG`` counts. Sorting makes
    the "mrwho" sequence numbers reproducible between runs.
    """
    return sorted(
        os.path.join(img_dir, name)
        for name in os.listdir(img_dir)
        if os.path.splitext(name)[1].lower() in _IMAGE_EXTS
        and os.path.isfile(os.path.join(img_dir, name))
    )


def _generate_caption(blip_decoder, image, device, opt):
    """Return BLIP's first caption for one PIL image.

    In nucleus mode ``opt.q_factor`` is used as ``top_p``. Otherwise beam
    search runs with 16 beams, ``min_length=opt.min_length`` and
    ``max_length=48``, and ``opt.q_factor`` is the ``repetition_penalty``.
    """
    # Inference only: no_grad avoids keeping autograd state for every image.
    with torch.no_grad():
        image_tensor = load_image(image, device=device)
        if opt.nucleus:
            captions = blip_decoder.generate(image_tensor, sample=True, top_p=opt.q_factor)
        else:
            captions = blip_decoder.generate(
                image_tensor,
                sample=False,
                num_beams=16,
                min_length=opt.min_length,
                max_length=48,
                repetition_penalty=opt.q_factor,
            )
    return captions[0]


def _safe_filename(caption):
    """Replace characters that cannot appear in a file name with ``_``."""
    return caption.translate(_UNSAFE_FILENAME_CHARS)


def _unique_output_path(out_dir, stem, file_ext, used):
    """Return ``out_dir/stem+file_ext``, adding ``_N`` if this run already wrote it.

    Two images with the same caption would otherwise silently overwrite each
    other. *used* holds the normcased paths written so far and is updated.
    """
    candidate = os.path.join(out_dir, f"{stem}{file_ext}")
    suffix = 1
    while os.path.normcase(candidate) in used:
        candidate = os.path.join(out_dir, f"{stem}_{suffix}{file_ext}")
        suffix += 1
    used.add(os.path.normcase(candidate))
    return candidate


async def main(opt):
    """Caption every image in ``opt.img_dir`` with BLIP and copy it to ``opt.out_dir``.

    Each image's original bytes are written unchanged to
    ``<out_dir>/<caption><ext>``. With ``opt.format == "mrwho"``, the caption
    is prefixed with a 5-digit zero-padded running counter and ``@``.

    On first use, the BLIP checkpoint is downloaded to ``.cache/`` under the
    current working directory. The model runs on CUDA when available,
    otherwise on the CPU. Images that PIL cannot decode are skipped with a
    message.

    Args:
        opt: parsed arguments from ``get_parser()`` (img_dir, out_dir, format,
            nucleus, q_factor, min_length).

    Raises:
        NotImplementedError: for output formats with no writer ("json",
            "parquet", "joepenna").
        FileNotFoundError: if ``opt.img_dir`` is not a directory.
    """
    import io

    print("starting")

    # Only these formats have a writer. "json" and "parquet" used to raise
    # only after the model was loaded and one image captioned, and "joepenna"
    # silently wrote nothing. Fail before doing any work instead.
    if opt.format not in ("filename", "mrwho"):
        raise NotImplementedError(f"output format {opt.format!r} is not implemented")

    # Importable because the __main__ block appends scripts/BLIP to sys.path.
    import models.blip

    input_dir = os.path.join(os.getcwd(), opt.img_dir)
    print("input_dir: ", input_dir)
    if not os.path.isdir(input_dir):
        raise FileNotFoundError(f"image directory not found: {input_dir}")
    image_paths = _list_images(input_dir)

    os.makedirs(opt.out_dir, exist_ok=True)

    config_path = os.path.join(os.getcwd(), "scripts/BLIP/configs/med_config.json")
    model_path = os.path.join(os.getcwd(), ".cache/model_base_caption_capfilt_large.pth")

    if not os.path.exists(model_path):
        print(f"Downloading model to {model_path}... please wait")
        await _download_model(_BLIP_MODEL_URL, model_path)
        print(f"Model cached to: {model_path}")
    else:
        print(f"Model already cached to: {model_path}")

    blip_decoder = models.blip.blip_decoder(
        pretrained=model_path, image_size=SIZE, vit='base', med_config=config_path
    )
    blip_decoder.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"loading model to {device}")
    blip_decoder = blip_decoder.to(device)

    used_paths = set()
    counter = 0
    for img_path in image_paths:
        print("working image: ", img_path)
        with open(img_path, "rb") as input_file:
            data = input_file.read()

        try:
            image = Image.open(io.BytesIO(data))
            image.load()  # decode now so a corrupt file is caught here, not mid-inference
        except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
            print(f"  skipping unreadable image: {err}")
            print()
            continue

        caption = _generate_caption(blip_decoder, image, device, opt)

        if opt.format == "mrwho":
            caption = f"{counter:05}@{caption}"
            counter += 1

        file_ext = os.path.splitext(img_path)[1]
        out_file = _unique_output_path(
            opt.out_dir, _safe_filename(caption), file_ext, used_paths
        )
        print(" out_file:", out_file)
        print()

        with open(out_file, "wb") as out:
            out.write(data)
|
||
|
|
||
|
def isWindows():
    """Tell whether the current interpreter runs on Windows.

    Returns:
        True when ``sys.platform`` begins with ``"win"`` (e.g. ``"win32"``),
        otherwise False.
    """
    platform_prefix = sys.platform[:3]
    return platform_prefix == "win"
|
||
|
|
||
|
if __name__ == "__main__":
    print("starting")
    parser = get_parser()
    opt = parser.parse_args()

    # Validate before any model download or load. The message now lists every
    # accepted value; it previously omitted 'mrwho' and 'joepenna'.
    if opt.format not in ["filename", "json", "mrwho", "joepenna", "parquet"]:
        raise ValueError(
            "format must be 'filename', 'json', 'mrwho', 'joepenna', or 'parquet'"
        )

    if isWindows():
        # NOTE(review): presumably works around aiohttp problems with the
        # default Proactor event loop on Windows. Confirm the reason.
        print("Windows detected, using asyncio.WindowsSelectorEventLoopPolicy")
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    else:
        print("Unix detected, using default asyncio event loop policy")

    # main() does `import models.blip`, which resolves against this directory.
    blip_path = os.path.join(os.getcwd(), "scripts/BLIP")
    sys.path.append(blip_path)

    asyncio.run(main(opt))
|
||
|
|