EveryDream2trainer/caption.py

140 lines
5.6 KiB
Python
Raw Normal View History

2023-03-20 07:21:13 -06:00
"""
Copyright [2022-2023] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from PIL import Image
import argparse
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration, GitProcessor, GitForCausalLM, AutoModel, AutoProcessor
import torch
from pynvml import *
2023-03-25 10:28:49 -06:00
import time
from colorama import Fore, Style
2023-03-20 07:21:13 -06:00
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
def get_gpu_memory_map(device_index=0):
    """Return the memory currently used on one NVIDIA GPU, in MB.

    Despite its historical name, this returns a single number, not a dict.

    Args:
        device_index: NVML index of the GPU to query. Defaults to 0, the first
            GPU, which was the previously hard-coded behavior.

    Returns:
        Used memory in megabytes (bytes / 1024 / 1024) as a float, or 0.0 when
        NVML cannot be initialized or queried (e.g. no NVIDIA driver present).
    """
    try:
        nvmlInit()
    except NVMLError:
        # The value is informational only; a missing driver/NVML library
        # (CPU-only machine) must not abort the captioning run.
        return 0.0
    try:
        handle = nvmlDeviceGetHandleByIndex(device_index)
        info = nvmlDeviceGetMemoryInfo(handle)
        return info.used / 1024 / 1024
    except NVMLError:
        return 0.0
    finally:
        # Balance every successful nvmlInit() with a shutdown so repeated
        # calls (once per image) do not keep accumulating NVML init state.
        nvmlShutdown()
def create_blip2_processor(model_name, device, dtype=torch.float16):
    """Load a BLIP-2 processor and captioning model.

    Args:
        model_name: Hugging Face model id or local path,
            e.g. "Salesforce/blip2-opt-2.7b".
        device: Device the model is moved to (e.g. "cuda" or "cpu").
        dtype: Torch dtype the model weights are loaded in.

    Returns:
        A ``(processor, model)`` tuple; the model is on ``device`` in eval mode.
    """
    processor = Blip2Processor.from_pretrained(model_name)
    # Load the weights from ``model_name`` (not a global CLI ``args``) so the
    # processor and model always come from the same checkpoint and the
    # function also works when this module is imported rather than run.
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_name, torch_dtype=dtype
    )
    model.to(device)
    model.eval()
    print(f"BLIP2 Model loaded: {model_name}")
    return processor, model
def create_git_processor(model_name, device, dtype=torch.float16):
    """Load a Microsoft GIT processor and captioning model.

    Args:
        model_name: Hugging Face model id or local path,
            e.g. "microsoft/git-large-textcaps".
        device: Device the model is moved to (e.g. "cuda" or "cpu").
        dtype: Torch dtype the model weights are loaded in.

    Returns:
        A ``(processor, model)`` tuple; the model is on ``device`` in eval mode.
    """
    processor = GitProcessor.from_pretrained(model_name)
    # Load the weights from ``model_name`` (not a global CLI ``args``) so the
    # processor and model always come from the same checkpoint and the
    # function also works when this module is imported rather than run.
    model = GitForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype
    )
    model.to(device)
    model.eval()
    print(f"GIT Model loaded: {model_name}")
    return processor, model
def create_auto_processor(model_name, device, dtype=torch.float16):
    """Load a processor and model through the generic ``Auto*`` classes.

    Used as the fallback for checkpoints that are neither BLIP-2 nor GIT.

    Args:
        model_name: Hugging Face model id or local path.
        device: Device the model is moved to (e.g. "cuda" or "cpu").
        dtype: Torch dtype the model weights are loaded in.

    Returns:
        A ``(processor, model)`` tuple; the model is on ``device`` in eval mode.
    """
    processor = AutoProcessor.from_pretrained(model_name)
    # Load the weights from ``model_name`` (not a global CLI ``args``) so the
    # processor and model always come from the same checkpoint and the
    # function also works when this module is imported rather than run.
    model = AutoModel.from_pretrained(
        model_name, torch_dtype=dtype
    )
    model.to(device)
    model.eval()
    print("Auto Model loaded")
    return processor, model
def _load_model(model_name, device, dtype):
    """Select the loader matching ``model_name`` and return ``(processor, model)``."""
    lowered = model_name.lower()
    if "salesforce/blip2-" in lowered:
        print(f"Using BLIP2 model: {model_name}")
        return create_blip2_processor(model_name, device, dtype)
    if "microsoft/git-" in lowered:
        print(f"Using GIT model: {model_name}")
        return create_git_processor(model_name, device, dtype)
    # Generic fallback: AutoModel does not work for BLIP/GIT checkpoints,
    # which is why those get the dedicated loaders above.
    return create_auto_processor(model_name, device, dtype)


def _caption_image(image_path, processor, model, device, dtype, max_new_tokens):
    """Generate and return one stripped caption for the image at ``image_path``."""
    # The context manager closes the file handle once preprocessing has read
    # the pixels; otherwise each image's handle stays open until GC.
    with Image.open(image_path) as image:
        inputs = processor(images=image, return_tensors="pt").to(device, dtype)
    # max_new_tokens is a generation setting: it has to be given to generate()
    # to bound the caption length; passed to the processor it has no effect.
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()


def main(args):
    """Caption every supported image found recursively under ``args.data_root``.

    Each caption is written next to its image as
    ``<image path without extension>.txt``. Images that already have such a
    caption file are skipped: no inference is run and the file is not
    overwritten.

    Args:
        args: Parsed CLI namespace providing ``model``, ``data_root``,
            ``force_cpu`` and ``max_new_tokens``.
    """
    device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
    # Choose dtype from the resolved device, not just --force_cpu: when CUDA is
    # unavailable we also run on CPU, where float16 is poorly supported.
    dtype = torch.float16 if device == "cuda" else torch.float32

    processor, model = _load_model(args.model, device, dtype)
    if device == "cuda":
        print(f"GPU memory used, after loading model: {get_gpu_memory_map()} MB")

    for root, _dirs, files in os.walk(args.data_root):
        for file in files:
            if os.path.splitext(file)[1].lower() not in SUPPORTED_EXT:
                continue
            full_file_path = os.path.join(root, file)
            caption_path = f"{os.path.splitext(full_file_path)[0]}.txt"
            # Check the caption file itself (not the extension-less stem) so
            # existing captions are preserved and not recomputed.
            if os.path.exists(caption_path):
                print(f"file: {file}, caption file already exists, skipping: {caption_path}")
                continue

            start_time = time.time()
            generated_text = _caption_image(
                full_file_path, processor, model, device, dtype, args.max_new_tokens
            )
            print(f"file: {file}, caption: {generated_text}")
            exec_time = time.time() - start_time
            gpu_note = f" GPU memory used: {get_gpu_memory_map()} MB" if device == "cuda" else ""
            print(f" Time for last caption: {exec_time} sec.{gpu_note}")

            # Explicit utf-8: captions may contain non-ASCII text that the
            # platform default encoding (e.g. cp1252 on Windows) cannot encode.
            with open(caption_path, "w", encoding="utf-8") as f:
                f.write(generated_text)
if __name__ == "__main__":
2023-03-25 10:28:49 -06:00
print(f"{Fore.CYAN}** Current supported models:{Style.RESET_ALL}")
print(" microsoft/git-base-textcaps")
print(" microsoft/git-large-textcaps")
print(" microsoft/git-large-r-textcaps")
print(" Salesforce/blip2-opt-2.7b - (9GB VRAM or recommend 32GB sys RAM)")
print(" Salesforce/blip2-opt-2.7b-coco - (9GB VRAM or recommend 32GB sys RAM)")
print(" Salesforce/blip2-opt-6.7b - (16.5GB VRAM or recommend 64GB sys RAM)")
print(" Salesforce/blip2-opt-6.7b-coco - (16.5GB VRAM or recommend 64GB sys RAM)")
print()
print(f"{Fore.CYAN} * The following will likely not work on any consumer GPUs or require huge sys RAM on CPU:{Style.RESET_ALL}")
print(" salesforce/blip2-flan-t5-xl")
print(" salesforce/blip2-flan-t5-xl-coco")
print(" salesforce/blip2-flan-t5-xxl")
2023-03-20 07:21:13 -06:00
parser = argparse.ArgumentParser()
parser.add_argument("--data_root", type=str, default="input", help="Path to images")
2023-03-25 10:28:49 -06:00
parser.add_argument("--model", type=str, default="salesforce/blip2-opt-2.7b", help="model from huggingface, ex. 'salesforce/blip2-opt-2.7b'")
2023-03-20 07:21:13 -06:00
parser.add_argument("--force_cpu", action="store_true", default=False, help="force using CPU even if GPU is available, may be useful to run huge models if you have a lot of system memory")
2023-03-25 10:28:49 -06:00
parser.add_argument("--max_new_tokens", type=int, default=24, help="max length for generated captions")
2023-03-20 07:21:13 -06:00
args = parser.parse_args()
print(f"** Using model: {args.model}")
print(f"** Captioning files in: {args.data_root}")
main(args)