adding llava 1.6
This commit is contained in:
parent
0a6aa4815e
commit
938fe5016d
|
@ -30,7 +30,8 @@ from PIL import Image
|
|||
import PIL.ImageOps as ImageOps
|
||||
from pynvml import *
|
||||
|
||||
from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, AutoProcessor, LlavaProcessor, AutoTokenizer
|
||||
from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, \
|
||||
AutoProcessor, LlavaProcessor, AutoTokenizer, AutoModelForVision2Seq, LlavaNextProcessor, LlavaNextForConditionalGeneration
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from colorama import Fore, Style
|
||||
from unidecode import unidecode
|
||||
|
@ -82,7 +83,6 @@ def create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type= "f
|
|||
llm_int8_threshold= 6.0,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
quant_method="bitsandbytes"
|
||||
)
|
||||
|
||||
class BaseModelWrapper:
|
||||
|
@ -321,6 +321,66 @@ class CogGLMManager(BaseModelWrapper):
|
|||
caption = self.tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)
|
||||
return caption
|
||||
|
||||
class LlavaNextManager(BaseModelWrapper):
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name)
|
||||
|
||||
def load_model(self, dtype: str = "auto"):
|
||||
self.tokenizer = LlamaTokenizer.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
|
||||
self.processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
|
||||
self.model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
||||
self.model.to("cuda")
|
||||
|
||||
def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
|
||||
gen_kwargs = self.get_gen_kwargs(args)
|
||||
image_marker = "<image>"
|
||||
prompt = f"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: {image_marker}\n{prompt} ASSISTANT:"
|
||||
prompt_len = len(prompt) - len(image_marker)
|
||||
prompt = prompt + args.starts_with
|
||||
print(f"prompt: {prompt}")
|
||||
print(f"image: {image}")
|
||||
inputs = self.processor(prompt, image, return_tensors="pt").to("cuda")
|
||||
|
||||
output = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
|
||||
caption = self.processor.decode(output[0], skip_special_tokens=True)
|
||||
print(f"raw return: {caption}")
|
||||
caption = caption[prompt_len:]
|
||||
if args.remove_starts_with:
|
||||
caption = caption[len(args.starts_with):].strip()
|
||||
return caption
|
||||
|
||||
# class AutoProcessAndModelManager(BaseModelWrapper):
|
||||
# def __init__(self, model_name: str):
|
||||
# super().__init__(model_name)
|
||||
|
||||
# def load_model(self, dtype: str = "auto"):
|
||||
# self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
# self.processor = AutoProcessor.from_pretrained(self.model_name)
|
||||
# # bnb_config = None
|
||||
# # bnb_config = self._maybe_create_bnb_config(dtype, auto_bnb_dtype="nf4")
|
||||
# #print(bnb_config)
|
||||
# self.model = AutoModelForVision2Seq.from_pretrained(self.model_name, quantization_config=create_bnb_config()).eval()
|
||||
# # if bnb_config is None:
|
||||
# # self.model.to("cuda", dtype=torch.float16)
|
||||
|
||||
# def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
|
||||
# messages = [{"role": "user","content": [{"type": "image"},{"type": "text", "text": prompt},]}]
|
||||
# prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
# inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
|
||||
# inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
||||
|
||||
# gen_kwargs = self.get_gen_kwargs(args)
|
||||
|
||||
# generated_ids = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
|
||||
# generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
# print(type(generated_texts))
|
||||
# print(len(generated_texts))
|
||||
# print(generated_texts[0])
|
||||
# caption = generated_texts[0].split("Assistant:")[-1]
|
||||
|
||||
# return caption
|
||||
|
||||
|
||||
class CogVLMManager(BaseModelWrapper):
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name)
|
||||
|
@ -429,7 +489,7 @@ def get_model_wrapper(model_name: str):
|
|||
# case x if "moai" in x:
|
||||
# #return MoaiManager(model_name)
|
||||
# return None
|
||||
case x if "llava" in x:
|
||||
case "xtuner/llava-llama-3-8b-v1_1-transformers":
|
||||
return XtunerLlavaModelManager(model_name)
|
||||
case "thudm/glm-4v-9b":
|
||||
return CogGLMManager(model_name)
|
||||
|
@ -437,6 +497,8 @@ def get_model_wrapper(model_name: str):
|
|||
return CogVLMManager(model_name)
|
||||
case x if x in ["thudm/cogvlm-chat-hf","thudm/cogagent-chat-hf"]:
|
||||
return CogVLMManager(model_name)
|
||||
case "llava-hf/llava-v1.6-vicuna-7b-hf":
|
||||
return LlavaNextManager(model_name)
|
||||
case None:
|
||||
return CogVLMManager(model_name)
|
||||
case _:
|
||||
|
@ -571,7 +633,7 @@ Notes:
|
|||
c. num_beams > 1 and do_sample true uses "beam-search multinomial sampling"
|
||||
d. num_beams > 1 and do_sample false uses "beam-search decoding"
|
||||
2. Max_length and max_new_tokens are mutually exclusive. If max_new_tokens is set, max_length is ignored. Default is max_length 2048 if nothing set.
|
||||
Using Max may abruptly end caption, consider modifying prompt or use length_penalty instead.
|
||||
Using Max may abruptly end caption, consider modifying prompt or use length_penalty instead. Some models react differently to these settings.
|
||||
|
||||
Find more info on the Huggingface Transformers documentation: https://huggingface.co/docs/transformers/main_classes/text_generation
|
||||
Parameters definitions and use map directly to their API.
|
||||
|
|
Loading…
Reference in New Issue