adding llava 1.6
parent 0a6aa4815e
commit 938fe5016d
@@ -30,7 +30,8 @@ from PIL import Image
 import PIL.ImageOps as ImageOps
 from pynvml import *

-from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, AutoProcessor, LlavaProcessor, AutoTokenizer
+from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, \
+    AutoProcessor, LlavaProcessor, AutoTokenizer, AutoModelForVision2Seq, LlavaNextProcessor, LlavaNextForConditionalGeneration
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from colorama import Fore, Style
 from unidecode import unidecode
@@ -82,7 +83,6 @@ def create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type= "f
         llm_int8_threshold= 6.0,
         load_in_4bit=True,
         load_in_8bit=False,
-        quant_method="bitsandbytes"
     )

 class BaseModelWrapper:
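For context (not part of the commit): a minimal sketch of a BitsAndBytesConfig built from the arguments the diff keeps. quant_method is not an explicit constructor parameter of BitsAndBytesConfig (the config reports its quantization method itself), which may be why the line is dropped above; the parameter values here are illustrative.

import torch
from transformers import BitsAndBytesConfig

# Illustrative values only; mirrors the arguments retained by the hunk above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    load_in_8bit=False,
    llm_int8_threshold=6.0,
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumed from the function's compute-dtype argument
    bnb_4bit_quant_type="nf4",              # "fp4" or "nf4"
)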
@@ -321,6 +321,66 @@ class CogGLMManager(BaseModelWrapper):
         caption = self.tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)
         return caption

+class LlavaNextManager(BaseModelWrapper):
+    def __init__(self, model_name: str):
+        super().__init__(model_name)
+
+    def load_model(self, dtype: str = "auto"):
+        self.tokenizer = LlamaTokenizer.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
+        self.processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
+        self.model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
+        self.model.to("cuda")
+
+    def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
+        gen_kwargs = self.get_gen_kwargs(args)
+        image_marker = "<image>"
+        prompt = f"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: {image_marker}\n{prompt} ASSISTANT:"
+        prompt_len = len(prompt) - len(image_marker)
+        prompt = prompt + args.starts_with
+        print(f"prompt: {prompt}")
+        print(f"image: {image}")
+        inputs = self.processor(prompt, image, return_tensors="pt").to("cuda")
+
+        output = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
+        caption = self.processor.decode(output[0], skip_special_tokens=True)
+        print(f"raw return: {caption}")
+        caption = caption[prompt_len:]
+        if args.remove_starts_with:
+            caption = caption[len(args.starts_with):].strip()
+        return caption
+
+# class AutoProcessAndModelManager(BaseModelWrapper):
+#     def __init__(self, model_name: str):
+#         super().__init__(model_name)
+
+#     def load_model(self, dtype: str = "auto"):
+#         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+#         self.processor = AutoProcessor.from_pretrained(self.model_name)
+#         # bnb_config = None
+#         # bnb_config = self._maybe_create_bnb_config(dtype, auto_bnb_dtype="nf4")
+#         #print(bnb_config)
+#         self.model = AutoModelForVision2Seq.from_pretrained(self.model_name, quantization_config=create_bnb_config()).eval()
+#         # if bnb_config is None:
+#         #     self.model.to("cuda", dtype=torch.float16)
+
+#     def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
+#         messages = [{"role": "user","content": [{"type": "image"},{"type": "text", "text": prompt},]}]
+#         prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+#         inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
+#         inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+#         gen_kwargs = self.get_gen_kwargs(args)
+
+#         generated_ids = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
+#         generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
+#         print(type(generated_texts))
+#         print(len(generated_texts))
+#         print(generated_texts[0])
+#         caption = generated_texts[0].split("Assistant:")[-1]
+
+#         return caption
+
+
 class CogVLMManager(BaseModelWrapper):
     def __init__(self, model_name: str):
         super().__init__(model_name)
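Not part of the commit: a minimal standalone sketch of the LLaVA 1.6 inference path that LlavaNextManager wraps, using the same checkpoint and the Transformers classes added to the imports above. The image path, prompt text, and generation settings are placeholders.

import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

# Same checkpoint the wrapper hardcodes in load_model().
model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"
processor = LlavaNextProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
).to("cuda")

image = Image.open("example.jpg")  # placeholder image path
# Vicuna-style chat template with the <image> marker, as used in caption() above.
prompt = ("A chat between a curious human and an artificial intelligence assistant. "
          "The assistant gives helpful, detailed, and polite answers to the human's questions. "
          "USER: <image>\nDescribe this image. ASSISTANT:")

inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=256)
print(processor.decode(output[0], skip_special_tokens=True))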
@@ -429,7 +489,7 @@ def get_model_wrapper(model_name: str):
         # case x if "moai" in x:
         # #return MoaiManager(model_name)
         # return None
-        case x if "llava" in x:
+        case "xtuner/llava-llama-3-8b-v1_1-transformers":
             return XtunerLlavaModelManager(model_name)
         case "thudm/glm-4v-9b":
             return CogGLMManager(model_name)
@@ -437,6 +497,8 @@ def get_model_wrapper(model_name: str):
             return CogVLMManager(model_name)
         case x if x in ["thudm/cogvlm-chat-hf","thudm/cogagent-chat-hf"]:
             return CogVLMManager(model_name)
+        case "llava-hf/llava-v1.6-vicuna-7b-hf":
+            return LlavaNextManager(model_name)
         case None:
             return CogVLMManager(model_name)
         case _:
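Also not part of the commit: a rough sketch of how the updated dispatch above would be exercised. get_model_wrapper, LlavaNextManager, and the caption() signature come from this file; the image path and the args namespace are hypothetical stand-ins for the script's parsed CLI options (caption() reads at least starts_with and remove_starts_with, and get_gen_kwargs() reads the generation settings).

from types import SimpleNamespace
from PIL import Image

# The new checkpoint string now routes to LlavaNextManager via the match arm above.
wrapper = get_model_wrapper("llava-hf/llava-v1.6-vicuna-7b-hf")
wrapper.load_model()

# Hypothetical stand-in for parsed CLI args; real runs pass the script's argparse namespace.
args = SimpleNamespace(starts_with="", remove_starts_with=False)

image = Image.open("example.jpg")  # placeholder path
caption = wrapper.caption("Describe this image.", image, args,
                          force_words_ids=None, bad_words_ids=None)
print(caption)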
@@ -571,7 +633,7 @@ Notes:
     c. num_beams > 1 and do_sample true uses "beam-search multinomial sampling"
     d. num_beams > 1 and do_sample false uses "beam-search decoding"
 2. Max_length and max_new_tokens are mutually exclusive. If max_new_tokens is set, max_length is ignored. Default is max_length 2048 if nothing set.
-    Using Max may abruptly end caption, consider modifying prompt or use length_penalty instead.
+    Using Max may abruptly end caption, consider modifying prompt or use length_penalty instead. Some models react differently to these settings.

 Find more info on the Huggingface Transformers documentation: https://huggingface.co/docs/transformers/main_classes/text_generation
 Parameters definitions and use map directly to their API.
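As a cross-reference for the notes above (again, not part of the commit), these generate() keyword sets correspond to the decoding strategies listed; the numeric values are illustrative.

# Greedy decoding: num_beams=1, do_sample=False
greedy = dict(num_beams=1, do_sample=False, max_new_tokens=128)

# Multinomial sampling: num_beams=1, do_sample=True
sampling = dict(num_beams=1, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=128)

# Beam-search multinomial sampling: num_beams > 1, do_sample=True
beam_sampling = dict(num_beams=4, do_sample=True, max_new_tokens=128)

# Beam-search decoding: num_beams > 1, do_sample=False; length_penalty reweights beam
# scores by sequence length, a gentler lever than a hard max_new_tokens cutoff.
beam_search = dict(num_beams=4, do_sample=False, length_penalty=1.1, max_new_tokens=128)

# max_new_tokens takes precedence over max_length when both are given.
# output = model.generate(**inputs, **beam_search)  # model/inputs as in the earlier sketch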