update caption_cog.py, transformers, and add peft
This commit is contained in:
parent
776edcf9d9
commit
d2a7a25da7
|
@ -21,7 +21,7 @@ import time
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Literal
|
||||
from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Dict, Any
|
||||
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
@ -30,15 +30,19 @@ from PIL import Image
|
|||
import PIL.ImageOps as ImageOps
|
||||
from pynvml import *
|
||||
|
||||
from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer
|
||||
from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer, BitsAndBytesConfig
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from colorama import Fore, Style
|
||||
|
||||
from plugins.caption_plugins import load_prompt_alteration_plugin
|
||||
from utils.patch_cog import patch_cog
|
||||
from data.gen_utils import image_generator, SUPPORTED_EXT
|
||||
|
||||
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
|
||||
IMAGE_SIZE: int = 490
|
||||
PATCH_SIZE: int = 14
|
||||
|
||||
patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions
|
||||
|
||||
def build_conversation_input_ids(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
*,
|
||||
|
@ -86,17 +90,6 @@ def build_conversation_input_ids(
|
|||
'images': images,
|
||||
}
|
||||
|
||||
def image_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
|
||||
if do_recurse:
|
||||
for root, dirs, files in os.walk(image_dir):
|
||||
for file in files:
|
||||
if any(file.endswith(ext) for ext in SUPPORTED_EXT):
|
||||
yield os.path.join(root, file)
|
||||
else:
|
||||
for file in os.listdir(image_dir):
|
||||
if any(file.endswith(ext) for ext in SUPPORTED_EXT):
|
||||
yield os.path.join(image_dir, file)
|
||||
|
||||
def get_gpu_memory_map():
|
||||
nvmlInit()
|
||||
handle = nvmlDeviceGetHandleByIndex(0)
|
||||
|
@ -114,10 +107,25 @@ def save_params(args, gen_kwargs):
|
|||
with open(save_path, "w") as f:
|
||||
f.write(pretty_print)
|
||||
|
||||
def create_bnb_config(args):
|
||||
return BitsAndBytesConfig(
|
||||
bnb_4bit_compute_dtype="float32",
|
||||
bnb_4bit_quant_type= "fp4",
|
||||
bnb_4bit_use_double_quant=False,
|
||||
llm_int8_enable_fp32_cpu_offload=False,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
llm_int8_skip_modules=None,
|
||||
llm_int8_threshold= 6.0,
|
||||
load_in_4bit=True,
|
||||
load_in_8bit=False,
|
||||
quant_method="bitsandbytes"
|
||||
)
|
||||
|
||||
def main(args):
|
||||
prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
|
||||
|
||||
bnb_config = create_bnb_config(args)
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
'THUDM/cogvlm-chat-hf',
|
||||
|
@ -125,7 +133,8 @@ def main(args):
|
|||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor
|
||||
#revision=... # no one is actually doing this
|
||||
load_in_4bit=not args.disable_4bit,
|
||||
#load_in_4bit=not args.disable_4bit,
|
||||
quantization_config=bnb_config,
|
||||
)
|
||||
|
||||
do_sample = args.top_k is not None or args.top_p is not None or args.temp is not None
|
||||
|
@ -214,12 +223,24 @@ def main(args):
|
|||
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
|
||||
'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
|
||||
'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)] for _ in range(args.num_beams)],
|
||||
'output_hidden_states': True,
|
||||
'return_dict': True
|
||||
}
|
||||
|
||||
# print(f"** inputs type: {type(inputs)}") # dict
|
||||
# print(f"** inputs len: {len(inputs)}") # 4
|
||||
# print(f"** inputs keys: {inputs.keys()}") # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'images'])
|
||||
# print(f"** inputs['images'] shape: {inputs['images'].shape}") # list has no shape
|
||||
# print(f"** image_path: {image_path}")
|
||||
|
||||
with torch.no_grad():
|
||||
#input_decoded = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
||||
#logging.debug(f"inputs decoded: {input_decoded}")
|
||||
#print(f"calling generate with input shapes: {inputs['input_ids'].shape}, {inputs['token_type_ids'].shape}, {inputs['attention_mask'].shape}, {inputs['images'][0][0].shape}")
|
||||
#calling generate with input shapes: torch.Size([1, 1352]), torch.Size([1, 1352]), torch.Size([1, 1352]), torch.Size([3, 490, 490])
|
||||
outputs = model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
|
||||
print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}")
|
||||
#print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}")
|
||||
|
||||
len_inputs = inputs['input_ids'].shape[1]
|
||||
outputs_without_prompt = outputs[:, len_inputs:]
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
torch==2.1.0
|
||||
torchvision==0.16.0
|
||||
transformers==4.35.0
|
||||
transformers>=4.38.2
|
||||
diffusers[torch]==0.21.4
|
||||
peft>=0.9.0
|
||||
pynvml==11.4.1
|
||||
bitsandbytes==0.41.1
|
||||
ftfy==6.1.1
|
||||
|
|
Loading…
Reference in New Issue