update caption_cog.py, transformers, and add peft
This commit is contained in:
parent
776edcf9d9
commit
d2a7a25da7
|
@ -21,7 +21,7 @@ import time
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Literal
|
from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Dict, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
@ -30,15 +30,19 @@ from PIL import Image
|
||||||
import PIL.ImageOps as ImageOps
|
import PIL.ImageOps as ImageOps
|
||||||
from pynvml import *
|
from pynvml import *
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer
|
from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer, BitsAndBytesConfig
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
|
|
||||||
from plugins.caption_plugins import load_prompt_alteration_plugin
|
from plugins.caption_plugins import load_prompt_alteration_plugin
|
||||||
|
from utils.patch_cog import patch_cog
|
||||||
|
from data.gen_utils import image_generator, SUPPORTED_EXT
|
||||||
|
|
||||||
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
|
|
||||||
IMAGE_SIZE: int = 490
|
IMAGE_SIZE: int = 490
|
||||||
PATCH_SIZE: int = 14
|
PATCH_SIZE: int = 14
|
||||||
|
|
||||||
|
patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions
|
||||||
|
|
||||||
def build_conversation_input_ids(
|
def build_conversation_input_ids(
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
*,
|
*,
|
||||||
|
@ -86,17 +90,6 @@ def build_conversation_input_ids(
|
||||||
'images': images,
|
'images': images,
|
||||||
}
|
}
|
||||||
|
|
||||||
def image_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
|
|
||||||
if do_recurse:
|
|
||||||
for root, dirs, files in os.walk(image_dir):
|
|
||||||
for file in files:
|
|
||||||
if any(file.endswith(ext) for ext in SUPPORTED_EXT):
|
|
||||||
yield os.path.join(root, file)
|
|
||||||
else:
|
|
||||||
for file in os.listdir(image_dir):
|
|
||||||
if any(file.endswith(ext) for ext in SUPPORTED_EXT):
|
|
||||||
yield os.path.join(image_dir, file)
|
|
||||||
|
|
||||||
def get_gpu_memory_map():
|
def get_gpu_memory_map():
|
||||||
nvmlInit()
|
nvmlInit()
|
||||||
handle = nvmlDeviceGetHandleByIndex(0)
|
handle = nvmlDeviceGetHandleByIndex(0)
|
||||||
|
@ -114,10 +107,25 @@ def save_params(args, gen_kwargs):
|
||||||
with open(save_path, "w") as f:
|
with open(save_path, "w") as f:
|
||||||
f.write(pretty_print)
|
f.write(pretty_print)
|
||||||
|
|
||||||
|
def create_bnb_config(args):
|
||||||
|
return BitsAndBytesConfig(
|
||||||
|
bnb_4bit_compute_dtype="float32",
|
||||||
|
bnb_4bit_quant_type= "fp4",
|
||||||
|
bnb_4bit_use_double_quant=False,
|
||||||
|
llm_int8_enable_fp32_cpu_offload=False,
|
||||||
|
llm_int8_has_fp16_weight=False,
|
||||||
|
llm_int8_skip_modules=None,
|
||||||
|
llm_int8_threshold= 6.0,
|
||||||
|
load_in_4bit=True,
|
||||||
|
load_in_8bit=False,
|
||||||
|
quant_method="bitsandbytes"
|
||||||
|
)
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
|
prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
|
||||||
|
|
||||||
|
bnb_config = create_bnb_config(args)
|
||||||
|
|
||||||
tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
|
tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
'THUDM/cogvlm-chat-hf',
|
'THUDM/cogvlm-chat-hf',
|
||||||
|
@ -125,7 +133,8 @@ def main(args):
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor
|
trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor
|
||||||
#revision=... # no one is actually doing this
|
#revision=... # no one is actually doing this
|
||||||
load_in_4bit=not args.disable_4bit,
|
#load_in_4bit=not args.disable_4bit,
|
||||||
|
quantization_config=bnb_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
do_sample = args.top_k is not None or args.top_p is not None or args.temp is not None
|
do_sample = args.top_k is not None or args.top_p is not None or args.temp is not None
|
||||||
|
@ -214,12 +223,24 @@ def main(args):
|
||||||
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
|
'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
|
||||||
'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
|
'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
|
||||||
'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)] for _ in range(args.num_beams)],
|
'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)] for _ in range(args.num_beams)],
|
||||||
|
'output_hidden_states': True,
|
||||||
|
'return_dict': True
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# print(f"** inputs type: {type(inputs)}") # dict
|
||||||
|
# print(f"** inputs len: {len(inputs)}") # 4
|
||||||
|
# print(f"** inputs keys: {inputs.keys()}") # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'images'])
|
||||||
|
# print(f"** inputs['images'] shape: {inputs['images'].shape}") # list has no shape
|
||||||
|
# print(f"** image_path: {image_path}")
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
#input_decoded = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
#input_decoded = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
||||||
#logging.debug(f"inputs decoded: {input_decoded}")
|
#logging.debug(f"inputs decoded: {input_decoded}")
|
||||||
|
#print(f"calling generate with input shapes: {inputs['input_ids'].shape}, {inputs['token_type_ids'].shape}, {inputs['attention_mask'].shape}, {inputs['images'][0][0].shape}")
|
||||||
|
#calling generate with input shapes: torch.Size([1, 1352]), torch.Size([1, 1352]), torch.Size([1, 1352]), torch.Size([3, 490, 490])
|
||||||
outputs = model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
|
outputs = model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
|
||||||
|
print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}")
|
||||||
|
#print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}")
|
||||||
|
|
||||||
len_inputs = inputs['input_ids'].shape[1]
|
len_inputs = inputs['input_ids'].shape[1]
|
||||||
outputs_without_prompt = outputs[:, len_inputs:]
|
outputs_without_prompt = outputs[:, len_inputs:]
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
torch==2.1.0
|
torch==2.1.0
|
||||||
torchvision==0.16.0
|
torchvision==0.16.0
|
||||||
transformers==4.35.0
|
transformers>=4.38.2
|
||||||
diffusers[torch]==0.21.4
|
diffusers[torch]==0.21.4
|
||||||
|
peft>=0.9.0
|
||||||
pynvml==11.4.1
|
pynvml==11.4.1
|
||||||
bitsandbytes==0.41.1
|
bitsandbytes==0.41.1
|
||||||
ftfy==6.1.1
|
ftfy==6.1.1
|
||||||
|
|
Loading…
Reference in New Issue