adding llava 1.6

Victor Hall 2024-06-10 15:49:31 -04:00
parent 0a6aa4815e
commit 938fe5016d
1 changed file with 66 additions and 4 deletions


@@ -30,7 +30,8 @@ from PIL import Image
 import PIL.ImageOps as ImageOps
 from pynvml import *
-from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, AutoProcessor, LlavaProcessor, AutoTokenizer
+from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, \
+    AutoProcessor, LlavaProcessor, AutoTokenizer, AutoModelForVision2Seq, LlavaNextProcessor, LlavaNextForConditionalGeneration
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from colorama import Fore, Style
 from unidecode import unidecode
@@ -82,7 +83,6 @@ def create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type= "f
         llm_int8_threshold= 6.0,
         load_in_4bit=True,
         load_in_8bit=False,
-        quant_method="bitsandbytes"
     )
 
 class BaseModelWrapper:
@@ -321,6 +321,66 @@ class CogGLMManager(BaseModelWrapper):
         caption = self.tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)
         return caption
 
+class LlavaNextManager(BaseModelWrapper):
+    def __init__(self, model_name: str):
+        super().__init__(model_name)
+
+    def load_model(self, dtype: str = "auto"):
+        self.tokenizer = LlamaTokenizer.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
+        self.processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
+        self.model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
+        self.model.to("cuda")
+
+    def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
+        gen_kwargs = self.get_gen_kwargs(args)
+
+        image_marker = "<image>"
+        prompt = f"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: {image_marker}\n{prompt} ASSISTANT:"
+        prompt_len = len(prompt) - len(image_marker)
+        prompt = prompt + args.starts_with
+
+        print(f"prompt: {prompt}")
+        print(f"image: {image}")
+        inputs = self.processor(prompt, image, return_tensors="pt").to("cuda")
+
+        output = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
+        caption = self.processor.decode(output[0], skip_special_tokens=True)
+        print(f"raw return: {caption}")
+        caption = caption[prompt_len:]
+
+        if args.remove_starts_with:
+            caption = caption[len(args.starts_with):].strip()
+
+        return caption
+
+# class AutoProcessAndModelManager(BaseModelWrapper):
+#     def __init__(self, model_name: str):
+#         super().__init__(model_name)
+
+#     def load_model(self, dtype: str = "auto"):
+#         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+#         self.processor = AutoProcessor.from_pretrained(self.model_name)
+#         # bnb_config = None
+#         # bnb_config = self._maybe_create_bnb_config(dtype, auto_bnb_dtype="nf4")
+#         #print(bnb_config)
+#         self.model = AutoModelForVision2Seq.from_pretrained(self.model_name, quantization_config=create_bnb_config()).eval()
+#         # if bnb_config is None:
+#         #     self.model.to("cuda", dtype=torch.float16)
+
+#     def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
+#         messages = [{"role": "user","content": [{"type": "image"},{"type": "text", "text": prompt},]}]
+#         prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+#         inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
+#         inputs = {k: v.to("cuda") for k, v in inputs.items()}
+#         gen_kwargs = self.get_gen_kwargs(args)
+#         generated_ids = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
+#         generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
+#         print(type(generated_texts))
+#         print(len(generated_texts))
+#         print(generated_texts[0])
+#         caption = generated_texts[0].split("Assistant:")[-1]
+#         return caption
+
 class CogVLMManager(BaseModelWrapper):
     def __init__(self, model_name: str):
         super().__init__(model_name)
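For context, a minimal usage sketch of the LlavaNextManager added above; this is not part of the commit. The Args fields are hypothetical stand-ins for the script's parsed CLI options, and BaseModelWrapper.get_gen_kwargs is assumed to read its generation settings from them.

from PIL import Image

class Args:
    # hypothetical stand-in for the captioning script's parsed CLI arguments
    starts_with = ""
    remove_starts_with = False
    num_beams = 3
    do_sample = False
    max_new_tokens = 75
    # ...plus any other generation options get_gen_kwargs() reads

manager = LlavaNextManager("llava-hf/llava-v1.6-vicuna-7b-hf")
manager.load_model()

image = Image.open("example.jpg").convert("RGB")  # example.jpg is a placeholder path
caption = manager.caption("Describe this image in detail.", image, Args(),
                          force_words_ids=None, bad_words_ids=None)
print(caption)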
@@ -429,7 +489,7 @@ def get_model_wrapper(model_name: str):
         # case x if "moai" in x:
         #     #return MoaiManager(model_name)
         #     return None
-        case x if "llava" in x:
+        case "xtuner/llava-llama-3-8b-v1_1-transformers":
             return XtunerLlavaModelManager(model_name)
         case "thudm/glm-4v-9b":
             return CogGLMManager(model_name)
@@ -437,6 +497,8 @@ def get_model_wrapper(model_name: str):
             return CogVLMManager(model_name)
         case x if x in ["thudm/cogvlm-chat-hf","thudm/cogagent-chat-hf"]:
             return CogVLMManager(model_name)
+        case "llava-hf/llava-v1.6-vicuna-7b-hf":
+            return LlavaNextManager(model_name)
         case None:
             return CogVLMManager(model_name)
         case _:
@@ -571,7 +633,7 @@ Notes:
         c. num_beams > 1 and do_sample true uses "beam-search multinomial sampling"
         d. num_beams > 1 and do_sample false uses "beam-search decoding"
     2. Max_length and max_new_tokens are mutually exclusive. If max_new_tokens is set, max_length is ignored. Default is max_length 2048 if nothing set.
-       Using Max may abruptly end caption, consider modifying prompt or use length_penalty instead.
+       Using Max may abruptly end caption, consider modifying prompt or use length_penalty instead. Some models react differently to these settings.
 
     Find more info on the Huggingface Transformers documentation: https://huggingface.co/docs/transformers/main_classes/text_generation
     Parameters definitions and use map directly to their API.
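A short sketch of how the notes above map onto a transformers generate() call; model, inputs, and processor are placeholders, and the values shown are illustrative rather than defaults of this script.

gen_kwargs = {
    "num_beams": 3,          # > 1 switches to beam search (cases c/d above)
    "do_sample": False,      # with num_beams > 1 this is "beam-search decoding"
    "max_new_tokens": 75,    # set this or max_length, not both
    "length_penalty": 1.2,   # softer way to steer caption length than a hard cap
}
output_ids = model.generate(**inputs, **gen_kwargs)
caption = processor.decode(output_ids[0], skip_special_tokens=True)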