diff --git a/caption_cog.py b/caption_cog.py
index 321ca15..bc573c1 100644
--- a/caption_cog.py
+++ b/caption_cog.py
@@ -30,7 +30,8 @@ from PIL import Image
 import PIL.ImageOps as ImageOps
 from pynvml import *
-from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, AutoProcessor, LlavaProcessor, AutoTokenizer
+from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, \
+    AutoProcessor, LlavaProcessor, AutoTokenizer, AutoModelForVision2Seq, LlavaNextProcessor, LlavaNextForConditionalGeneration
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from colorama import Fore, Style
 from unidecode import unidecode
 
@@ -82,7 +83,6 @@ def create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type= "f
         llm_int8_threshold= 6.0,
         load_in_4bit=True,
         load_in_8bit=False,
-        quant_method="bitsandbytes"
     )
 
 class BaseModelWrapper:
@@ -321,6 +321,66 @@ class CogGLMManager(BaseModelWrapper):
         caption = self.tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)
         return caption
 
+class LlavaNextManager(BaseModelWrapper):
+    def __init__(self, model_name: str):
+        super().__init__(model_name)
+
+    def load_model(self, dtype: str = "auto"):
+        self.tokenizer = LlamaTokenizer.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
+        self.processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
+        self.model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
+        self.model.to("cuda")
+
+    def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
+        gen_kwargs = self.get_gen_kwargs(args)
+        image_marker = "<image>"
+        prompt = f"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: {image_marker}\n{prompt} ASSISTANT:"
+        prompt_len = len(prompt) - len(image_marker)
+        prompt = prompt + args.starts_with
+        print(f"prompt: {prompt}")
+        print(f"image: {image}")
+        inputs = self.processor(prompt, image, return_tensors="pt").to("cuda")
+
+        output = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
+        caption = self.processor.decode(output[0], skip_special_tokens=True)
+        print(f"raw return: {caption}")
+        caption = caption[prompt_len:]
+        if args.remove_starts_with:
+            caption = caption[len(args.starts_with):].strip()
+        return caption
+
+# class AutoProcessAndModelManager(BaseModelWrapper):
+#     def __init__(self, model_name: str):
+#         super().__init__(model_name)
+
+#     def load_model(self, dtype: str = "auto"):
+#         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+#         self.processor = AutoProcessor.from_pretrained(self.model_name)
+#         # bnb_config = None
+#         # bnb_config = self._maybe_create_bnb_config(dtype, auto_bnb_dtype="nf4")
+#         #print(bnb_config)
+#         self.model = AutoModelForVision2Seq.from_pretrained(self.model_name, quantization_config=create_bnb_config()).eval()
+#         # if bnb_config is None:
+#         #     self.model.to("cuda", dtype=torch.float16)
+
+#     def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
+#         messages = [{"role": "user","content": [{"type": "image"},{"type": "text", "text": prompt},]}]
+#         prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+#         inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
+#         inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+#         gen_kwargs = self.get_gen_kwargs(args)
+
+#         generated_ids = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
+#         generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
+#         print(type(generated_texts))
+#         print(len(generated_texts))
+#         print(generated_texts[0])
+#         caption = generated_texts[0].split("Assistant:")[-1]
+
+#         return caption
+
+
 class CogVLMManager(BaseModelWrapper):
     def __init__(self, model_name: str):
         super().__init__(model_name)
@@ -429,7 +489,7 @@ def get_model_wrapper(model_name: str):
         # case x if "moai" in x:
         #     #return MoaiManager(model_name)
         #     return None
-        case x if "llava" in x:
+        case "xtuner/llava-llama-3-8b-v1_1-transformers":
            return XtunerLlavaModelManager(model_name)
        case "thudm/glm-4v-9b":
            return CogGLMManager(model_name)
@@ -437,6 +497,8 @@
            return CogVLMManager(model_name)
        case x if x in ["thudm/cogvlm-chat-hf","thudm/cogagent-chat-hf"]:
            return CogVLMManager(model_name)
+        case "llava-hf/llava-v1.6-vicuna-7b-hf":
+            return LlavaNextManager(model_name)
        case None:
            return CogVLMManager(model_name)
        case _:
@@ -571,7 +633,7 @@ Notes:
        c. num_beams > 1 and do_sample true uses "beam-search multinomial sampling"
        d. num_beams > 1 and do_sample false uses "beam-search decoding"
    2. Max_length and max_new_tokens are mutually exclusive. If max_new_tokens is set, max_length is ignored. Default is max_length 2048 if nothing set.
-       Using Max may abruptly end caption, consider modifying prompt or use length_penalty instead.
+       Using Max may abruptly end caption, consider modifying prompt or use length_penalty instead. Some models react differently to these settings.
    Find more info on the Huggingface Transformers documentation: https://huggingface.co/docs/transformers/main_classes/text_generation
    Parameters definitions and use map directly to their API.
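
For reviewers, a minimal standalone sketch of the LLaVA-Next path this patch wires up, outside the wrapper classes. It assumes the same llava-hf/llava-v1.6-vicuna-7b-hf checkpoint and a CUDA device; the example.jpg path, the exact prompt wording, and the max_new_tokens value are illustrative, not taken from the patch defaults.

import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

# Load the processor and fp16 model the same way LlavaNextManager.load_model does.
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-vicuna-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True
).to("cuda")

# Vicuna-style chat prompt; the <image> marker tells the processor where to splice in image features.
prompt = ("A chat between a curious human and an artificial intelligence assistant. "
          "The assistant gives helpful, detailed, and polite answers to the human's questions. "
          "USER: <image>\nDescribe the image. ASSISTANT:")

image = Image.open("example.jpg")  # hypothetical input image
inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")

# Greedy decoding with a small token budget; the wrapper passes its gen_kwargs to generate() in the same way.
output = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(output[0], skip_special_tokens=True))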