This commit is contained in:
Victor Hall 2023-06-29 18:12:52 -04:00
parent aa7e004869
commit 01b77f295e
8 changed files with 204 additions and 0 deletions

198
caption_fl.py Normal file
View File

@ -0,0 +1,198 @@
"""
Copyright [2022-2023] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from PIL import Image
import argparse
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration, GitProcessor, GitForCausalLM, AutoModel, AutoProcessor
from huggingface_hub import hf_hub_download
from open_flamingo import create_model_and_transforms
import torch
from pynvml import *
import time
from colorama import Fore, Style
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
def get_gpu_memory_map():
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(0)
info = nvmlDeviceGetMemoryInfo(handle)
nvmlShutdown()
return info.used/1024/1024
def remove_duplicates(string):
words = string.split(', ') # Split the string into individual words
unique_words = []
for word in words:
if word not in unique_words:
unique_words.append(word)
else:
break # Stop appending words once a duplicate is found
return ', '.join(unique_words)
def get_examples(example_root, image_processor):
examples = []
for root, dirs, files in os.walk(example_root):
for file in files:
ext = os.path.splitext(file)[-1].lower()
if ext in SUPPORTED_EXT:
#get .txt file of same base name
txt_file = os.path.splitext(file)[0] + ".txt"
with open(os.path.join(root, txt_file), 'r') as f:
caption = f.read()
image = Image.open(os.path.join(root, file))
vision_x = [image_processor(image).unsqueeze(0)]
#vision_x = torch.cat(vision_x, dim=0)
#vision_x = vision_x.unsqueeze(1).unsqueeze(0)
examples.append((caption, vision_x))
for x in examples:
print(f" ** Example: {x[0]}")
return examples
def main(args):
device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
if args.prompt:
prompt = args.prompt
else:
prompt = "<image>: "
print(f" using prompt: {prompt}")
if "mpt7b" in args.model:
lang_encoder_path="anas-awadalla/mpt-7b"
tokenizer_path="anas-awadalla/mpt-7b"
elif "mpt1b" in args.model:
lang_encoder_path="anas-awadalla/mpt-1b"
tokenizer_path="anas-awadalla/mpt-1b"
model, image_processor, tokenizer = create_model_and_transforms(
clip_vision_encoder_path="ViT-L-14",
clip_vision_encoder_pretrained="openai",
lang_encoder_path=lang_encoder_path,
tokenizer_path=tokenizer_path,
cross_attn_every_n_layers=1,
)
tokenizer.padding_side = "left"
checkpoint_path = hf_hub_download(args.model, "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)
print(f"GPU memory used, before loading model: {get_gpu_memory_map()} MB")
model.to(0, dtype=dtype)
print(f"GPU memory used, after loading model: {get_gpu_memory_map()} MB")
# examples give few shot learning for captioning the novel image
examples = get_examples(args.example_root, image_processor)
prompt = ""
output_prompt = "Output:"
per_image_prompt = "<image> " + output_prompt
for example in iter(examples):
prompt += f"{per_image_prompt}{example[0]}<|endofchunk|>"
prompt += per_image_prompt # prepare for novel example
prompt = prompt.replace("\n", "") # in case captions had newlines
print(f" \n** Final full prompt with example pairs: {prompt}")
# os.walk all files in args.data_root recursively
for root, dirs, files in os.walk(args.data_root):
for file in files:
#get file extension
ext = os.path.splitext(file)[1]
if ext.lower() in SUPPORTED_EXT:
start_time = time.time()
full_file_path = os.path.join(root, file)
image = Image.open(full_file_path)
vision_x = [vx[1][0] for vx in examples]
vision_x.append(image_processor(image).unsqueeze(0))
vision_x = torch.cat(vision_x, dim=0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
vision_x = vision_x.to(device, dtype=dtype)
lang_x = tokenizer(
[prompt], # blank for image captioning
return_tensors="pt",
)
lang_x.to(device)
input_ids = lang_x["input_ids"].to(device)
with torch.cuda.amp.autocast(dtype=dtype):
generated_text = model.generate(
vision_x=vision_x,
lang_x=input_ids,
attention_mask=lang_x["attention_mask"],
max_new_tokens=args.max_new_tokens,
min_new_tokens=args.min_new_tokens,
num_beams=args.num_beams,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty,
)
del vision_x
del lang_x
# trim and clean
generated_text = tokenizer.decode(generated_text[0][len(input_ids[0]):], skip_special_tokens=True)
generated_text = generated_text.split(output_prompt)[0]
generated_text = remove_duplicates(generated_text)
exec_time = time.time() - start_time
print(f"* Caption: {generated_text}")
print(f" Time for last caption: {exec_time} sec. GPU memory used: {get_gpu_memory_map()} MB")
name = os.path.splitext(full_file_path)[0]
if not os.path.exists(name):
with open(f"{name}.txt", "w") as f:
f.write(generated_text)
if __name__ == "__main__":
print(f"Available models:")
print(f" openflamingo/OpenFlamingo-9B-vitl-mpt7b (default)")
print(f" openflamingo/OpenFlamingo-3B-vitl-mpt1b")
print(f" openflamingo/OpenFlamingo-4B-vitl-rpj3b")
parser = argparse.ArgumentParser()
parser.add_argument("--data_root", type=str, default="input", help="Path to images")
parser.add_argument("--example_root", type=str, default="examples", help="Path to 2-3 precaptioned images to guide generation")
parser.add_argument("--model", type=str, default="openflamingo/OpenFlamingo-9B-vitl-mpt7b", help="Model name or path")
parser.add_argument("--force_cpu", action="store_true", default=False, help="force using CPU even if GPU is available")
parser.add_argument("--min_new_tokens", type=int, default=20, help="minimum number of tokens to generate")
parser.add_argument("--max_new_tokens", type=int, default=50, help="maximum number of tokens to generate")
parser.add_argument("--num_beams", type=int, default=8, help="number of beams, more is more accurate but slower")
parser.add_argument("--prompt", type=str, default="Output: ", help="prompt to use for generation, default is 'Output: '")
parser.add_argument("--temperature", type=float, default=1.0, help="temperature for sampling, 1.0 is default")
parser.add_argument("--top_k", type=int, default=0, help="top_k sampling, 0 is default")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p sampling, 0.9 is default")
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty, 1.0 is default")
parser.add_argument("--length_penalty", type=float, default=1.0, help="length penalty, 1.0 is default")
args = parser.parse_args()
print(f"** OPEN-FLAMINGO ** Captioning files in: {args.data_root}")
print(f"** Using model: {args.model}")
main(args)

Binary file not shown.

After

Width:  |  Height:  |  Size: 317 KiB

View File

@ -0,0 +1 @@
a red and blue 'Underground' sign found in London

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

View File

@ -0,0 +1 @@
a white bowl filled with creamy hummus placed on a white countertop

View File

@ -20,3 +20,5 @@ OmegaConf==2.2.3
numpy==1.23.5
wandb
colorama
safetensors
open-flamingo==2.0.0

View File

@ -21,6 +21,8 @@ pip install numpy==1.23.5
pip install lion-pytorch
pip install compel~=1.1.3
pip install dadaptation
pip install safetensors
pip install open-flamingo==2.0.0
python utils/get_yamls.py
GOTO :eof