flamingo
This commit is contained in:
parent
aa7e004869
commit
01b77f295e
|
@ -0,0 +1,198 @@
|
|||
"""
|
||||
Copyright [2022-2023] Victor C Hall
|
||||
|
||||
Licensed under the GNU Affero General Public License;
|
||||
You may not use this code except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
https://www.gnu.org/licenses/agpl-3.0.en.html
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from PIL import Image
|
||||
import argparse
|
||||
import requests
|
||||
from transformers import Blip2Processor, Blip2ForConditionalGeneration, GitProcessor, GitForCausalLM, AutoModel, AutoProcessor
|
||||
from huggingface_hub import hf_hub_download
|
||||
from open_flamingo import create_model_and_transforms
|
||||
|
||||
import torch
|
||||
from pynvml import *
|
||||
|
||||
import time
|
||||
from colorama import Fore, Style
|
||||
|
||||
|
||||
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
|
||||
|
||||
def get_gpu_memory_map():
|
||||
nvmlInit()
|
||||
handle = nvmlDeviceGetHandleByIndex(0)
|
||||
info = nvmlDeviceGetMemoryInfo(handle)
|
||||
nvmlShutdown()
|
||||
return info.used/1024/1024
|
||||
|
||||
def remove_duplicates(string):
|
||||
words = string.split(', ') # Split the string into individual words
|
||||
unique_words = []
|
||||
|
||||
for word in words:
|
||||
if word not in unique_words:
|
||||
unique_words.append(word)
|
||||
else:
|
||||
break # Stop appending words once a duplicate is found
|
||||
|
||||
return ', '.join(unique_words)
|
||||
|
||||
def get_examples(example_root, image_processor):
|
||||
examples = []
|
||||
for root, dirs, files in os.walk(example_root):
|
||||
for file in files:
|
||||
ext = os.path.splitext(file)[-1].lower()
|
||||
if ext in SUPPORTED_EXT:
|
||||
#get .txt file of same base name
|
||||
txt_file = os.path.splitext(file)[0] + ".txt"
|
||||
with open(os.path.join(root, txt_file), 'r') as f:
|
||||
caption = f.read()
|
||||
image = Image.open(os.path.join(root, file))
|
||||
vision_x = [image_processor(image).unsqueeze(0)]
|
||||
#vision_x = torch.cat(vision_x, dim=0)
|
||||
#vision_x = vision_x.unsqueeze(1).unsqueeze(0)
|
||||
examples.append((caption, vision_x))
|
||||
for x in examples:
|
||||
print(f" ** Example: {x[0]}")
|
||||
return examples
|
||||
|
||||
def main(args):
|
||||
device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
|
||||
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||
|
||||
if args.prompt:
|
||||
prompt = args.prompt
|
||||
else:
|
||||
prompt = "<image>: "
|
||||
print(f" using prompt: {prompt}")
|
||||
|
||||
if "mpt7b" in args.model:
|
||||
lang_encoder_path="anas-awadalla/mpt-7b"
|
||||
tokenizer_path="anas-awadalla/mpt-7b"
|
||||
elif "mpt1b" in args.model:
|
||||
lang_encoder_path="anas-awadalla/mpt-1b"
|
||||
tokenizer_path="anas-awadalla/mpt-1b"
|
||||
|
||||
model, image_processor, tokenizer = create_model_and_transforms(
|
||||
clip_vision_encoder_path="ViT-L-14",
|
||||
clip_vision_encoder_pretrained="openai",
|
||||
lang_encoder_path=lang_encoder_path,
|
||||
tokenizer_path=tokenizer_path,
|
||||
cross_attn_every_n_layers=1,
|
||||
)
|
||||
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
checkpoint_path = hf_hub_download(args.model, "checkpoint.pt")
|
||||
model.load_state_dict(torch.load(checkpoint_path), strict=False)
|
||||
print(f"GPU memory used, before loading model: {get_gpu_memory_map()} MB")
|
||||
model.to(0, dtype=dtype)
|
||||
print(f"GPU memory used, after loading model: {get_gpu_memory_map()} MB")
|
||||
|
||||
# examples give few shot learning for captioning the novel image
|
||||
examples = get_examples(args.example_root, image_processor)
|
||||
|
||||
prompt = ""
|
||||
output_prompt = "Output:"
|
||||
per_image_prompt = "<image> " + output_prompt
|
||||
|
||||
for example in iter(examples):
|
||||
prompt += f"{per_image_prompt}{example[0]}<|endofchunk|>"
|
||||
prompt += per_image_prompt # prepare for novel example
|
||||
prompt = prompt.replace("\n", "") # in case captions had newlines
|
||||
print(f" \n** Final full prompt with example pairs: {prompt}")
|
||||
|
||||
# os.walk all files in args.data_root recursively
|
||||
for root, dirs, files in os.walk(args.data_root):
|
||||
for file in files:
|
||||
#get file extension
|
||||
ext = os.path.splitext(file)[1]
|
||||
if ext.lower() in SUPPORTED_EXT:
|
||||
start_time = time.time()
|
||||
|
||||
full_file_path = os.path.join(root, file)
|
||||
image = Image.open(full_file_path)
|
||||
|
||||
vision_x = [vx[1][0] for vx in examples]
|
||||
vision_x.append(image_processor(image).unsqueeze(0))
|
||||
vision_x = torch.cat(vision_x, dim=0)
|
||||
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
|
||||
vision_x = vision_x.to(device, dtype=dtype)
|
||||
|
||||
lang_x = tokenizer(
|
||||
[prompt], # blank for image captioning
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_x.to(device)
|
||||
|
||||
input_ids = lang_x["input_ids"].to(device)
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=dtype):
|
||||
generated_text = model.generate(
|
||||
vision_x=vision_x,
|
||||
lang_x=input_ids,
|
||||
attention_mask=lang_x["attention_mask"],
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
min_new_tokens=args.min_new_tokens,
|
||||
num_beams=args.num_beams,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
)
|
||||
del vision_x
|
||||
del lang_x
|
||||
|
||||
# trim and clean
|
||||
generated_text = tokenizer.decode(generated_text[0][len(input_ids[0]):], skip_special_tokens=True)
|
||||
generated_text = generated_text.split(output_prompt)[0]
|
||||
generated_text = remove_duplicates(generated_text)
|
||||
|
||||
exec_time = time.time() - start_time
|
||||
print(f"* Caption: {generated_text}")
|
||||
|
||||
print(f" Time for last caption: {exec_time} sec. GPU memory used: {get_gpu_memory_map()} MB")
|
||||
|
||||
name = os.path.splitext(full_file_path)[0]
|
||||
if not os.path.exists(name):
|
||||
with open(f"{name}.txt", "w") as f:
|
||||
f.write(generated_text)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"Available models:")
|
||||
print(f" openflamingo/OpenFlamingo-9B-vitl-mpt7b (default)")
|
||||
print(f" openflamingo/OpenFlamingo-3B-vitl-mpt1b")
|
||||
print(f" openflamingo/OpenFlamingo-4B-vitl-rpj3b")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data_root", type=str, default="input", help="Path to images")
|
||||
parser.add_argument("--example_root", type=str, default="examples", help="Path to 2-3 precaptioned images to guide generation")
|
||||
parser.add_argument("--model", type=str, default="openflamingo/OpenFlamingo-9B-vitl-mpt7b", help="Model name or path")
|
||||
parser.add_argument("--force_cpu", action="store_true", default=False, help="force using CPU even if GPU is available")
|
||||
parser.add_argument("--min_new_tokens", type=int, default=20, help="minimum number of tokens to generate")
|
||||
parser.add_argument("--max_new_tokens", type=int, default=50, help="maximum number of tokens to generate")
|
||||
parser.add_argument("--num_beams", type=int, default=8, help="number of beams, more is more accurate but slower")
|
||||
parser.add_argument("--prompt", type=str, default="Output: ", help="prompt to use for generation, default is 'Output: '")
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="temperature for sampling, 1.0 is default")
|
||||
parser.add_argument("--top_k", type=int, default=0, help="top_k sampling, 0 is default")
|
||||
parser.add_argument("--top_p", type=float, default=0.9, help="top_p sampling, 0.9 is default")
|
||||
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty, 1.0 is default")
|
||||
parser.add_argument("--length_penalty", type=float, default=1.0, help="length penalty, 1.0 is default")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"** OPEN-FLAMINGO ** Captioning files in: {args.data_root}")
|
||||
print(f"** Using model: {args.model}")
|
||||
main(args)
|
Binary file not shown.
After Width: | Height: | Size: 317 KiB |
|
@ -0,0 +1 @@
|
|||
a red and blue 'Underground' sign found in London
|
BIN
examples/a white bowl filled with creamy hummus placed on a white countertop.jpg
Executable file
BIN
examples/a white bowl filled with creamy hummus placed on a white countertop.jpg
Executable file
Binary file not shown.
After Width: | Height: | Size: 42 KiB |
|
@ -0,0 +1 @@
|
|||
a white bowl filled with creamy hummus placed on a white countertop
|
|
@ -20,3 +20,5 @@ OmegaConf==2.2.3
|
|||
numpy==1.23.5
|
||||
wandb
|
||||
colorama
|
||||
safetensors
|
||||
open-flamingo==2.0.0
|
|
@ -21,6 +21,8 @@ pip install numpy==1.23.5
|
|||
pip install lion-pytorch
|
||||
pip install compel~=1.1.3
|
||||
pip install dadaptation
|
||||
pip install safetensors
|
||||
pip install open-flamingo==2.0.0
|
||||
python utils/get_yamls.py
|
||||
GOTO :eof
|
||||
|
||||
|
|
Loading…
Reference in New Issue