"""
Copyright [2022-2023] Victor C Hall

Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at

    https://www.gnu.org/licenses/agpl-3.0.en.html

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import io
import argparse
import time
import torch
from PIL import Image
from pynvml import *
from transformers import AutoProcessor, AutoModelForVision2Seq
2023-11-03 11:50:47 -06:00
import colorama
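
# "<grounding>" is the Kosmos-2 prompt token that asks the model to ground noun phrases
# in its output to bounding boxes in the image.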
GROUNDING = "<grounding>"

SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
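
# Returns current VRAM usage of GPU 0 in MB via NVML; used only for the per-image log line.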
def get_gpu_memory_map():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    nvmlShutdown()
    return info.used / 1024 / 1024
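
# Strips prompt string a from the start of generated caption b, tolerating surrounding whitespace.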
def remove_starting_string(a, b):
    if b.startswith(a):
        return b[len(a):]  # Remove string A from the beginning of string B
    elif b.strip().startswith(a.strip()):
        return b.strip()[len(a.strip()):]
    else:
        return b
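
# Captions every supported image under args.data_root with Kosmos-2 and writes a .txt
# (and optionally a .ent) file next to each image.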
def main(args):
    model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
    processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
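
    # Default to fp32; when running on CUDA the weights are cast to the requested --dtype.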
    dtype = torch.float32
    if not args.cpu:
        if args.dtype == "fp16":
            dtype = torch.float16
        elif args.dtype == "bf16":
            dtype = torch.bfloat16
        elif args.dtype == "fp32":
            dtype = torch.float32
        model = model.to(dtype=dtype).cuda()
        print(f"Using cuda, model dtype: {model.dtype}")
    else:
        print(f"Using cpu, model dtype: {model.dtype}")
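
    # Walk data_root recursively and caption every file with a supported image extension.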
    for root, dirs, files in os.walk(args.data_root):
        for file in files:
            # get file extension
            ext = os.path.splitext(file)[1]
            if ext.lower() in SUPPORTED_EXT:
                start_time = time.time()
                full_file_path = os.path.join(root, file)
                image = Image.open(full_file_path)
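
                # In phrase mode the prompt is treated as a comma-separated list of phrases, each wrapped
                # in <phrase> tags so Kosmos-2 grounds it to a box; otherwise the prompt is used as a
                # free-form caption prefix after the <grounding> token.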
                if args.phrase_mode:
                    text = GROUNDING + " ".join(["<phrase>" + x.strip() + "</phrase>" for x in args.prompt.split(",")])
                else:
                    text = GROUNDING + args.prompt

                inputs = processor(text=text, images=image, return_tensors="pt")
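
                # Generate under CUDA autocast only when running on GPU with a reduced-precision dtype;
                # inputs are moved to the GPU unless --cpu was requested.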
                with torch.cuda.amp.autocast(enabled=(args.dtype != "fp32" and not args.cpu), dtype=dtype):
                    generated_ids = model.generate(
                        pixel_values=inputs["pixel_values"].cuda() if not args.cpu else inputs["pixel_values"],
                        input_ids=inputs["input_ids"].cuda() if not args.cpu else inputs["input_ids"],
                        attention_mask=inputs["attention_mask"].cuda() if not args.cpu else inputs["attention_mask"],
                        image_embeds=None,
                        image_embeds_position_mask=inputs["image_embeds_position_mask"].cuda() if not args.cpu else inputs["image_embeds_position_mask"],
                        use_cache=True,
                        max_new_tokens=args.max_new_tokens,
                    )

                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                # remove remaining special tokens to get just the caption and entities
                processed_text, entities = processor.post_process_generation(generated_text)

                if not args.keep_prompt:
                    processed_text = remove_starting_string(args.prompt, processed_text)

                print(f"File: {full_file_path}, Generated caption: {processed_text}")

                name = os.path.splitext(full_file_path)[0]
                if (not os.path.exists(f"{name}.txt") or args.overwrite) and not args.save_entities_only:
                    with open(f"{name}.txt", "w") as f:
                        f.write(processed_text)
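
                # Each entity is (phrase, (start, end) character span in the caption, list of normalized
                # (x1, y1, x2, y2) boxes), as returned by the processor's post_process_generation.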
                if args.save_entities and (not os.path.exists(f"{name}.ent") or args.overwrite):
                    with open(f"{name}.ent", "w") as entities_file:
                        entities_file.write(str(entities))

                gpu_mb_used = get_gpu_memory_map()
                print(f"gpu usage: {gpu_mb_used:.1f} mb, time taken: {time.time() - start_time:.2f} seconds")
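
# Example invocation (illustrative; the script filename is an assumption):
#   python caption_kosmos2.py --data_root /path/to/images --dtype fp16 --save_entities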
if __name__ == "__main__":
    print("Kosmos-2 captioning script")
    parser = argparse.ArgumentParser()
    parser.description = "Kosmos-2 captioning script"
    parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption")
    parser.add_argument("--prompt", type=str, default="Describe this image in detail: ", help="Prompt for generating caption")
    parser.add_argument("--phrase_mode", action="store_true", default=False, help="uses 'phrase mode' grounding, interprets prompt as csv list of phrases to ground.")
    parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved")
    parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate")
    parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file")
    parser.add_argument("--save_entities_only", action="store_true", default=False, help="Only save coord box with entities to a separate .ent file, do not write caption .txt")
    parser.add_argument("--overwrite", action="store_true", default=False, help="will overwrite .txt and .ent files if they exist")
    parser.add_argument("--cpu", action="store_true", default=False, help="use cpu instead of cuda")
    parser.add_argument("--dtype", type=str, default="fp16", help="force a different dtype if using GPU (fp16, bf16, fp32) (default: fp16)")
    args = parser.parse_args()

    parser.print_help()

    if args.save_entities_only:
        args.save_entities = True

    if not args.prompt.startswith(" "):
        args.prompt = " " + args.prompt

    print(f"Captioning images in {args.data_root} with prompt:{args.prompt}")
    print("Ideas for prompts:")
    print("  Describe this image in detail: (default)")
    print("  An image of")
    print("  A two sentence description of this image:")
    print()

    main(args)