diff --git a/caption.py b/caption.py index 9346950..63d2829 100644 --- a/caption.py +++ b/caption.py @@ -24,6 +24,9 @@ from transformers import Blip2Processor, Blip2ForConditionalGeneration, GitProce import torch from pynvml import * +import time +from colorama import Fore, Style + SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"] def get_gpu_memory_map(): @@ -39,30 +42,30 @@ def get_gpu_memory_map(): info = nvmlDeviceGetMemoryInfo(handle) return info.used/1024/1024 -def create_blip2_processor(model_name, device): +def create_blip2_processor(model_name, device, dtype=torch.float16): processor = Blip2Processor.from_pretrained(model_name) model = Blip2ForConditionalGeneration.from_pretrained( - args.model, torch_dtype=torch.float16 + args.model, torch_dtype=dtype ) model.to(device) model.eval() print(f"BLIP2 Model loaded: {model_name}") return processor, model -def create_git_processor(model_name, device): +def create_git_processor(model_name, device, dtype=torch.float16): processor = GitProcessor.from_pretrained(model_name) model = GitForCausalLM.from_pretrained( - args.model, torch_dtype=torch.float16 + args.model, torch_dtype=dtype ) model.to(device) model.eval() print(f"GIT Model loaded: {model_name}") return processor, model -def create_auto_processor(model_name, device): +def create_auto_processor(model_name, device, dtype=torch.float16): processor = AutoProcessor.from_pretrained(model_name) model = AutoModel.from_pretrained( - args.model, torch_dtype=torch.float16 + args.model, torch_dtype=dtype ) model.to(device) model.eval() @@ -71,33 +74,38 @@ def create_auto_processor(model_name, device): def main(args): device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu" + dtype = torch.float32 if args.force_cpu else torch.float16 # automodel doesn't work with git/blip - if "salesforce/blip2-" in args.model: - processor, model = create_blip2_processor(args.model, device) - elif "microsoft/git-" in args.model: - processor, model = create_git_processor(args.model, device) + if "salesforce/blip2-" in args.model.lower(): + print(f"Using BLIP2 model: {args.model}") + processor, model = create_blip2_processor(args.model, device, dtype) + elif "microsoft/git-" in args.model.lower(): + print(f"Using GIT model: {args.model}") + processor, model = create_git_processor(args.model, device, dtype) else: # try to use auto model? doesn't work with blip/git - processor, model = create_auto_processor(args.model, device) + processor, model = create_auto_processor(args.model, device, dtype) - print(f"GPU memory used: {get_gpu_memory_map()} MB") + print(f"GPU memory used, after loading model: {get_gpu_memory_map()} MB") # os.walk all files in args.data_root recursively for root, dirs, files in os.walk(args.data_root): - for full_file_path in files: + for file in files: #get file extension - ext = os.path.splitext(full_file_path)[1] + ext = os.path.splitext(file)[1] if ext.lower() in SUPPORTED_EXT: - full_file_path = os.path.join(root, full_file_path) + full_file_path = os.path.join(root, file) image = Image.open(full_file_path) + start_time = time.time() - inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) + inputs = processor(images=image, return_tensors="pt", max_new_tokens=args.max_new_tokens).to(device, dtype) generated_ids = model.generate(**inputs) generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() - print(generated_text) - print(f"GPU memory used: {get_gpu_memory_map()} MB") + print(f"file: {file}, caption: {generated_text}") + exec_time = time.time() - start_time + print(f" Time for last caption: {exec_time} sec. GPU memory used: {get_gpu_memory_map()} MB") # get bare name name = os.path.splitext(full_file_path)[0] @@ -107,23 +115,25 @@ def main(args): f.write(generated_text) if __name__ == "__main__": - print("** Current supported models:") - print(" microsoft/git-base-textcaps") - print(" microsoft/git-large-textcaps") - print(" microsoft/git-large-r-textcaps") - print(" Salesforce/blip2-opt-2.7b (9GB VRAM)") - print(" Salesforce/blip2-opt-2.7b-coco") - print(" * The following will not likely work on any consumer cards:") - print(" Salesforce/blip2-opt-6.7b") - print(" Salesforce/blip2-opt-6.7b-coco") - print(" Salesforce/blip2-flan-t5-xl") - print(" Salesforce/blip2-flan-t5-xl-coco") - print(" Salesforce/blip2-flan-t5-xxl") + print(f"{Fore.CYAN}** Current supported models:{Style.RESET_ALL}") + print(" microsoft/git-base-textcaps") + print(" microsoft/git-large-textcaps") + print(" microsoft/git-large-r-textcaps") + print(" Salesforce/blip2-opt-2.7b - (9GB VRAM or recommend 32GB sys RAM)") + print(" Salesforce/blip2-opt-2.7b-coco - (9GB VRAM or recommend 32GB sys RAM)") + print(" Salesforce/blip2-opt-6.7b - (16.5GB VRAM or recommend 64GB sys RAM)") + print(" Salesforce/blip2-opt-6.7b-coco - (16.5GB VRAM or recommend 64GB sys RAM)") + print() + print(f"{Fore.CYAN} * The following will likely not work on any consumer GPUs or require huge sys RAM on CPU:{Style.RESET_ALL}") + print(" salesforce/blip2-flan-t5-xl") + print(" salesforce/blip2-flan-t5-xl-coco") + print(" salesforce/blip2-flan-t5-xxl") parser = argparse.ArgumentParser() parser.add_argument("--data_root", type=str, default="input", help="Path to images") - parser.add_argument("--model", type=str, default="Salesforce/blip2-opt-2.7b", help="model from huggingface, ex. 'Salesforce/blip2-opt-2.7b'") + parser.add_argument("--model", type=str, default="salesforce/blip2-opt-2.7b", help="model from huggingface, ex. 'salesforce/blip2-opt-2.7b'") parser.add_argument("--force_cpu", action="store_true", default=False, help="force using CPU even if GPU is available, may be useful to run huge models if you have a lot of system memory") + parser.add_argument("--max_new_tokens", type=int, default=24, help="max length for generated captions") args = parser.parse_args() print(f"** Using model: {args.model}")