"""
Copyright [2022-2023] Victor C Hall

Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
    https://www.gnu.org/licenses/agpl-3.0.en.html

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import io
import argparse
import time
import json
import logging
import re
from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Dict, Any

import torch
from torchvision import transforms
from PIL import Image
import PIL.ImageOps as ImageOps
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlShutdown
from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, AutoProcessor, LlavaProcessor, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutputWithPast
from colorama import Fore, Style

from plugins.caption_plugins import load_prompt_alteration_plugin
from utils.patch_cog import patch_cog
from utils.ed_logging import configure_logging
from data.generators import image_path_generator, SUPPORTED_EXT

try:
    from moai.load_moai import prepare_moai
except ImportError:
    print("moai not found, skipping")

Image.MAX_IMAGE_PIXELS = 715827880 * 4  # expand PIL's decompression-bomb size limit

IMAGE_SIZE: int = 490
PATCH_SIZE: int = 14

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
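
# IMAGE_SIZE and PATCH_SIZE describe the CogVLM vision input (490px image, 14px patches):
# each image contributes (490 // 14) ** 2 + 2 = 1227 vision tokens (the +2 presumably the
# begin/end-of-image tokens); see CogVLMManager._build_conversation_input_ids below.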

def get_gpu_memory_map():
    """Returns the current VRAM usage of GPU 0 in GB."""
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    nvmlShutdown()
    return info.used / 1024 / 1024 / 1024  # bytes -> GB, matching the "GB" label where this is logged

def save_params(args, gen_kwargs):
    """Writes the run's args and generation kwargs to a txt file next to the images, for reproducibility."""
    save_path = os.path.join(args.image_dir, "caption_cog_params.txt")
    args_dict = {
        "args": vars(args),
        "gen_kwargs": gen_kwargs,
    }
    pretty_print = json.dumps(args_dict, indent=4)
    with open(save_path, "w") as f:
        f.write(pretty_print)

def create_bnb_config():
    """Builds the bitsandbytes config for 4-bit (fp4) quantized loading with bfloat16 compute."""
    return BitsAndBytesConfig(
        bnb_4bit_compute_dtype="bfloat16",
        bnb_4bit_quant_type="fp4",
        bnb_4bit_use_double_quant=False,
        llm_int8_enable_fp32_cpu_offload=False,
        llm_int8_has_fp16_weight=False,
        llm_int8_skip_modules=None,
        llm_int8_threshold=6.0,
        load_in_4bit=True,
        load_in_8bit=False,
        quant_method="bitsandbytes"
    )
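
# This 4-bit config (fp4 weights, bfloat16 compute) is currently only consumed by
# CogVLMManager.load_model below; the Llava and base wrapper paths load in fp16 unquantized.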

class BaseModelWrapper:
    def __init__(self, model_name):
        self.model_name = model_name
        logging.info(f"Loading {model_name}")

    def load_model(self, bits: int = 4, grad_ckpt: bool = False, lora: bool = False, dtype: str = "fp16"):
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(0)
        self.tokenizer = AutoProcessor.from_pretrained(self.model_name)
        return self.model, self.tokenizer

    def get_gen_kwargs(self, args):
        """Maps CLI args to transformers generate() kwargs; sampling is enabled only if top_k, top_p, or temp is set."""
        gen_kwargs = {
            "max_length": args.max_length,
            "do_sample": args.top_k is not None or args.top_p is not None or args.temp is not None,
            "length_penalty": args.length_penalty,
            "num_beams": args.num_beams,
            "temperature": args.temp,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "repetition_penalty": args.repetition_penalty,
            "no_repeat_ngram_size": args.no_repeat_ngram_size,
            "min_new_tokens": args.min_new_tokens,
            "max_new_tokens": args.max_new_tokens,
        }
        logging.info(gen_kwargs)

        if args.max_new_tokens is not None:
            logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
            del gen_kwargs["max_length"]

        if not gen_kwargs["do_sample"]:
            logging.info("** Using greedy sampling")
            del gen_kwargs["top_k"]
            del gen_kwargs["top_p"]
            del gen_kwargs["temperature"]
        else:
            logging.info("** Sampling enabled")

        return gen_kwargs

    def caption(self, prompt, args):
        return ""

class XtunerLlavaModelManager(BaseModelWrapper):  # https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers
    def __init__(self, model_name: str = "xtuner/llava-llama-3-8b-v1_1-transformers"):
        super().__init__(model_name)

    def load_model(self, bits: int = 4, grad_ckpt: bool = False, lora: bool = False, dtype: str = "fp16"):
        self.model = LlavaForConditionalGeneration.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            #quantization_config=create_bnb_config()
        ).to(0)

        self.processor = LlavaProcessor.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        logging.debug(f"self.tokenizer: {self.tokenizer}")
        return self.model, self.tokenizer

    def get_inputs(self, image: Image.Image, prompt: str):
        return self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)

    def _build_conversational_input_ids(self, prompt, starts_with):
        """Wraps the prompt in the Llama-3 chat template, priming the assistant turn with starts_with."""
        return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
                f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")

    def _clean_caption(self, caption, args):
        """
        Cleans up the caption: removes newlines and excess whitespace, and strips some boilerplate phrases Llava tends to add.
        """
        logging.debug(f"**Llava pre-cleaning caption: {caption}")

        caption = caption.split(".")
        #sentence_count = min(4, len(caption))
        caption = ".".join(caption[0:-1]) + "."  # drops the trailing, possibly truncated, sentence fragment

        caption = caption.replace("\n", "")
        caption = caption.replace("  ", " ")

        caption = re.sub(r"The image does not contain .*?\.", "", caption)
        caption = re.sub(r"Please note that this description is based on .*?\.", "", caption)
        caption = re.sub(r", adding to .*? overall appearance", "", caption)
        caption = re.sub(r"The rest of .*? is not visible in the image, focusing .*?\.", "", caption)
        caption = re.sub(r"hinting at .*?\.", "", caption)
        caption = caption.replace(", who is the main subject of the image,", " ")

        logging.debug(f"**Llava post-cleaning caption: {caption}")
        return caption
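
    # e.g. roughly: "A cat sits on a mat. The image does not contain any text." -> "A cat sits on a mat. "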

    def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=None):
        gen_kwargs = self.get_gen_kwargs(args)

        prompt = self._build_conversational_input_ids(prompt, args.starts_with)
        inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)

        # inputs['input_ids'].shape: torch.Size([1, 34])
        # inputs['attention_mask'].shape: torch.Size([1, 34])
        # inputs['pixel_values'].shape: torch.Size([1, 3, 336, 336])

        inputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "pixel_values": inputs["pixel_values"],
        }

        len_inputs = inputs["input_ids"].shape[1]
        outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
        caption = self.processor.decode(outputs[0][len_inputs:], skip_special_tokens=True)

        caption = self._clean_caption(caption, args)
        return caption


class MoaiManager:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.moai_model = None
        self.moai_processor = None
        self.seg_model = None
        self.seg_processor = None
        self.od_model = None
        self.od_processor = None
        self.sgg_model = None
        self.ocr_model = None

    def load_model(self, bits: int = 4, grad_ckpt: bool = False, lora: bool = False, dtype: str = "fp16"):
        moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
            = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
        self.moai_model = moai_model
        self.moai_processor = moai_processor
        self.seg_model = seg_model
        self.seg_processor = seg_processor
        self.od_model = od_model
        self.od_processor = od_processor
        self.sgg_model = sgg_model
        self.ocr_model = ocr_model
        return moai_model, moai_processor

    def get_inputs(self, image: Image.Image, prompt: str):
        moai_inputs = self.moai_model.demo_process(image=image,
                                                   prompt=prompt,
                                                   processor=self.moai_processor,
                                                   seg_model=self.seg_model,
                                                   seg_processor=self.seg_processor,
                                                   od_model=self.od_model,
                                                   od_processor=self.od_processor,
                                                   sgg_model=self.sgg_model,
                                                   ocr_model=self.ocr_model,
                                                   device="cuda:0")
        return moai_inputs

    # def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any:
    #     with torch.inference_mode():
    #         generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache)
    #         answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split("[U")[0]
    #         return answer


class CogVLMManager(BaseModelWrapper):
    def __init__(self, model_name: str):
        super().__init__(model_name)
        self.model_name = "THUDM/cogvlm-chat-hf"
        patch_cog()  # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions

    def load_model(self):
        self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            quantization_config=create_bnb_config()
        )
        return self.model, self.tokenizer

    def _build_conversation_input_ids(self,
                                      *,
                                      query: str,
                                      history: Optional[List[Tuple[str, str]]] = None,
                                      images: Optional[List[Image.Image]] = None,
                                      starts_with: Optional[str] = None,
                                      ):
        # based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py
        image_size: int = IMAGE_SIZE
        patch_size: int = PATCH_SIZE
        assert images is None or len(images) <= 1, "multiple images are not supported yet"
        history = history or []

        text = f"Question: {query} Answer: "
        text += starts_with if starts_with is not None else ""

        input_ids = [self.tokenizer.bos_token_id]
        token_type_ids = [0]
        if images is not None and len(images) == 1:
            # vision: resize to 490x490 and normalize with CLIP mean/std
            transform = transforms.Compose(
                [
                    transforms.Resize(
                        (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
                    ),
                    transforms.ToTensor(),
                    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                ]
            )
            images = [transform(images[0])]
            vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2  # 35*35 patches + 2 = 1227
            input_ids += [self.tokenizer.pad_token_id] * vision_token_num
            token_type_ids += [1] * vision_token_num

        text_ids = self.tokenizer.encode(text, add_special_tokens=False)
        input_ids += text_ids
        token_type_ids += [0] * len(text_ids)
        attention_mask = [1] * len(input_ids)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "images": images,
        }
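
    # Resulting sequence layout: [BOS] + 1227 image placeholder tokens (token_type 1) + text tokens
    # (token_type 0); the 1259-token example noted in caption() below is 1 + 1227 + 31.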

    def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=None):
        gen_kwargs = self.get_gen_kwargs(args)
        inputs = self._build_conversation_input_ids(query=prompt, history=history, images=[image], starts_with=args.starts_with)

        # inputs['input_ids'].shape: torch.Size([1259])
        # inputs['attention_mask'].shape: torch.Size([1259])
        # inputs['images'][0].shape: torch.Size([3, 490, 490])

        inputs = {
            "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
            "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
            "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
            "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)] for _ in range(args.num_beams)],
            "output_hidden_states": True,
            "return_dict": True
        }

        # inputs['input_ids'].shape: torch.Size([1, 1259])
        # inputs['attention_mask'].shape: torch.Size([1, 1259])
        # inputs['images'][0][0].shape: torch.Size([3, 490, 490])

        outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)

        len_inputs = inputs["input_ids"].shape[1]
        outputs_without_prompt = outputs[:, len_inputs:]  # strip the prompt tokens from the returned sequence
        caption = self.tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)
        return caption


def get_model_wrapper(model_name: str):
    if "moai" in model_name:
        return MoaiManager(model_name)
    elif "llava" in model_name:
        return XtunerLlavaModelManager(model_name)
    else:
        return CogVLMManager(model_name)
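
# Dispatch is substring-based, e.g. "xtuner/llava-llama-3-8b-v1_1-transformers" -> XtunerLlavaModelManager,
# anything else (such as the default "THUDM/cogvlm-chat-hf") -> CogVLMManager.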


def get_inputs_dict(inputs, args):
    """Packs raw CogVLM inputs into the batched, device-placed dict expected by generate(); mirrors CogVLMManager.caption."""
    return {
        "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
        "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
        "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
        "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)] for _ in range(args.num_beams)],
        "output_hidden_states": True,
        "return_dict": True
    }


def main(args):
    prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
    model_wrapper = get_model_wrapper(args.model)
    model_wrapper.load_model()

    args.append = args.append or ""
    if len(args.append) > 0:
        args.append = " " + args.append.strip()

    gen_kwargs = model_wrapper.get_gen_kwargs(args)

    force_words_ids = None
    if args.force_words is not None:
        force_words = args.force_words.split(",")
        logging.info(f"** force_words: {Fore.LIGHTGREEN_EX}{force_words}{Style.RESET_ALL}")
        if "cog" in args.model:
            force_words_ids = model_wrapper.tokenizer(force_words, add_special_tokens=False)["input_ids"] if force_words else []
        else:
            force_words_ids = model_wrapper.tokenizer(force_words)["input_ids"] if force_words else []

    bad_words_ids = None
    if args.bad_words is not None:
        bad_words = args.bad_words.split(",")
        logging.info(f"** bad_words: {Fore.LIGHTGREEN_EX}{bad_words}{Style.RESET_ALL}")
        bad_words_ids = model_wrapper.tokenizer(bad_words, add_special_tokens=False)["input_ids"] if bad_words else []

    logging.info(f"** gen_kwargs: \n{Fore.LIGHTGREEN_EX}{gen_kwargs}{Style.RESET_ALL}")
    save_params(args, gen_kwargs)
    total_start_time = time.time()
    i_processed = 0

    starts_with = args.starts_with.strip() if args.starts_with is not None else ""

    for i, image_path in enumerate(image_path_generator(args.image_dir, do_recurse=not args.no_recurse)):
        candidate_caption_path = image_path.replace(os.path.splitext(image_path)[-1], ".txt")

        if args.no_overwrite and os.path.exists(candidate_caption_path):
            logging.warning(f"Skipping {image_path}, caption already exists.")
            continue

        cap_start_time = time.time()
        image = Image.open(image_path)

        try:
            image = image.convert("RGB")
            image = ImageOps.exif_transpose(image)
        except Exception as e:
            logging.warning(f"Non-fatal error processing {image_path}: {e}")
            continue

        pixel_count = image.height * image.width
        if pixel_count < args.min_pixels:
            logging.warning(f" * Image under {args.min_pixels} pixels, skipping. Path: {image_path}")
            continue

        logging.debug(f"__ Prompt before plugin: {Fore.LIGHTGREEN_EX}{args.prompt}{Style.RESET_ALL}")
        prompt = prompt_plugin_fn(image_path, args=args)
        logging.debug(f"__ Modified prompt after plugin: {Fore.LIGHTGREEN_EX}{prompt}{Style.RESET_ALL}")

        with torch.no_grad():
            caption = model_wrapper.caption(prompt, image, args, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)

        if not args.remove_starts_with:
            # deal with the caption starting with a comma or other punctuation
            if not re.match(r"^\W", caption):
                caption = starts_with + " " + caption
            else:
                caption = starts_with + caption

        caption += args.append

        with open(candidate_caption_path, "w") as f:
            f.write(caption)

        vram_gb = get_gpu_memory_map()
        elapsed_time = time.time() - cap_start_time
        logging.info(f"n: {i:05}, VRAM: {Fore.LIGHTYELLOW_EX}{vram_gb:0.1f}GB{Style.RESET_ALL}, elapsed: {Fore.LIGHTYELLOW_EX}{elapsed_time:0.1f}{Style.RESET_ALL} sec, sqrt_pixels: {pow(float(pixel_count), 0.5):0.1f}, Captioned {Fore.LIGHTYELLOW_EX}{image_path}{Style.RESET_ALL}:")
        logging.info(f"{Fore.LIGHTCYAN_EX}{caption}{Style.RESET_ALL}")
        i_processed += 1

    if i_processed == 0:
        logging.info(f"** No images found in {args.image_dir} with extension in {SUPPORTED_EXT} OR no images left to caption (did you use --no_overwrite?)")
        exit(1)

    total_elapsed_time = time.time() - total_start_time
    avg_time = total_elapsed_time / i_processed
    hh_mm_ss = time.strftime("%H:%M:%S", time.gmtime(total_elapsed_time))
    logging.info(f"** Done captioning {args.image_dir} with prompt '{prompt}', total elapsed: {hh_mm_ss} (hh_mm_ss), avg: {avg_time:0.1f} sec/image")


EXAMPLES = """ex.
Basic example:
  python caption_cog.py --image_dir /mnt/mydata/kyrie/ --prompt 'Describe this image in detail, including the subject matter and medium of the artwork.'

Use probabilistic sampling by setting any of top_k, top_p, or temp:
  python caption_cog.py --image_dir \"c:/users/chadley/my documents/pictures\" --prompt \"What is this?\" --top_p 0.9

Use beam search and probabilistic sampling:
  python caption_cog.py --image_dir \"c:/users/chadley/my documents/pictures\" --prompt \"Write a description.\" --max_new_tokens 75 --num_beams 4 --temp 0.9 --top_k 3 --top_p 0.9 --repetition_penalty 1.0 --no_repeat_ngram_size 0 --min_new_tokens 5

Force "cat" and "dog" and disallow the word "depicts":
  python caption_cog.py --image_dir /mnt/lcl/nvme/mldata/test --num_beams 3 --force_words "cat,dog" --bad_words "depicts"

Use a lot of beams and try to control the length with length_penalty:
  python caption_cog.py --image_dir /mnt/lcl/nvme/mldata/test --num_beams 8 --length_penalty 0.8 --prompt "Write a single sentence description."
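
Skip images that already have a caption file, seed the caption start, and append a fixed suffix:
  python caption_cog.py --image_dir /mnt/lcl/nvme/mldata/test --no_overwrite --starts_with "An image of" --append "painted by John Doe"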

Notes:
  1. Setting top_k, top_p, or temp enables probabilistic sampling (aka "do_sample"); otherwise greedy sampling is used.
     a. num_beams 1 and do_sample false uses "greedy decoding"
     b. num_beams 1 and do_sample true uses "multinomial sampling"
     c. num_beams > 1 and do_sample true uses "beam-search multinomial sampling"
     d. num_beams > 1 and do_sample false uses "beam-search decoding"
  2. max_length and max_new_tokens are mutually exclusive. If max_new_tokens is set, max_length is ignored. The default is max_length 2048 if neither is set.
     Hitting either limit may abruptly cut the caption off; consider modifying the prompt or using length_penalty instead.

Find more info in the Hugging Face Transformers documentation: https://huggingface.co/docs/transformers/main_classes/text_generation
Parameter definitions and usage map directly to that API.
"""

DESCRIPTION = f"** {Fore.LIGHTBLUE_EX}CogVLM captioning script{Style.RESET_ALL} **\n Use --help for usage."

if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--batch_size", type=int, default=1, help="Batch size for batch processing. Does NOT work with COG! (def: 1)")
    argparser.add_argument("--debug", action="store_true", help="Enable debug logging")
    argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.")
    argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling")
    argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for beam search (def: 2)")
    argparser.add_argument("--top_k", type=int, default=None, help="Top-k, filters to the k highest-probability tokens before sampling")
    argparser.add_argument("--top_p", type=float, default=None, help="Top-p, samples from the smallest set of tokens with cumulative probability >= p")
    argparser.add_argument("--repetition_penalty", type=float, default=1.0, help="Repetition penalty")
    argparser.add_argument("--no_repeat_ngram_size", type=int, default=0, help="No-repeat n-gram size")
    argparser.add_argument("--min_new_tokens", type=int, default=5, help="Minimum number of tokens in returned caption.")
    argparser.add_argument("--max_new_tokens", type=int, default=None, help="Maximum number of tokens in returned caption.")
    argparser.add_argument("--max_length", type=int, default=2048, help="Alternative to max_new_tokens, limits total context length.")
    argparser.add_argument("--length_penalty", type=float, default=1.0, help="Length penalty, lower values encourage shorter captions.")
    argparser.add_argument("--prompt", type=str, default="Write a description.", help="Prompt that will guide captioning")
    argparser.add_argument("--image_dir", type=str, default=None, help="Path to folder of images to caption")
    argparser.add_argument("--no_overwrite", action="store_true", help="Skips captioning images that already have a caption file.")
    argparser.add_argument("--force_words", type=str, default=None, help="Forces the model to include these words in the caption, use CSV format.")
    argparser.add_argument("--bad_words", type=str, default=None, help="Words that will not be allowed, use CSV format.")
    argparser.add_argument("--append", type=str, default=None, help="Extra string to append to all captions. ex. 'painted by John Doe'")
    argparser.add_argument("--no_recurse", action="store_true", help="Do not recurse into subdirectories.")
    argparser.add_argument("--prompt_plugin", type=str, default=None, help="Function name to modify prompt, edit code to add plugins.")
    argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.")
    argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
    argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
    argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.")
    argparser.add_argument("--min_pixels", type=int, default=1, help="Minimum total pixel count to caption; images under the limit are skipped")
    args, unknown_args = argparser.parse_known_args()

    configure_logging(args, "caption_cog.log")

    # pass unknown args through to the args namespace as key/value pairs so prompt plugins can use them
    unknown_args_dict = {}
    for i in range(0, len(unknown_args) - 1, 2):
        key = unknown_args[i].lstrip('-')  # remove the leading '--'
        value = unknown_args[i + 1]
        unknown_args_dict[key] = value
        setattr(args, key, value)
    logging.info(f"** Unknown args have been added to args for plugins: {Fore.LIGHTGREEN_EX}{unknown_args_dict}{Style.RESET_ALL}")

    print(DESCRIPTION)
    print(EXAMPLES)

    if args.image_dir is None:
        logging.error(f"** {Fore.RED}Error: image_dir is required.{Style.RESET_ALL}")
        exit(1)

    if not os.path.exists(args.image_dir):
        logging.error(f"** {Fore.RED}Error: image_dir {args.image_dir} does not exist.{Style.RESET_ALL}")
        exit(1)

    startprint = f"** Running: {args.image_dir} with prompt '{args.prompt}"
    if args.starts_with is not None:
        startprint += f" {args.starts_with}'"
    else:
        startprint += "'"
    startprint += " <caption>"
    if args.append is not None:
        startprint += f", and appending: {args.append}"
    logging.info(startprint)

    main(args)