update caption_cog.py, transformers, and add peft

Victor Hall 2024-03-13 18:38:30 -04:00
parent 776edcf9d9
commit d2a7a25da7
2 changed files with 38 additions and 16 deletions

caption_cog.py

@@ -21,7 +21,7 @@ import time
import json
import logging
import re
from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Literal
from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Dict, Any
import torch
from torchvision import transforms
@@ -30,15 +30,19 @@ from PIL import Image
import PIL.ImageOps as ImageOps
from pynvml import *
from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer
from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer, BitsAndBytesConfig
from transformers.modeling_outputs import BaseModelOutputWithPast
from colorama import Fore, Style
from plugins.caption_plugins import load_prompt_alteration_plugin
from utils.patch_cog import patch_cog
from data.gen_utils import image_generator, SUPPORTED_EXT
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
IMAGE_SIZE: int = 490
PATCH_SIZE: int = 14
patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions
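# Illustration only -- the actual fix lives in utils/patch_cog.py and may
# differ -- but the general shape of such a patch is to make state-dict loading
# tolerant of stale rotary-embedding "inv_freq" buffers from older checkpoints.
# A hypothetical sketch, defined here purely for illustration and never called:
def _illustrative_inv_freq_patch():
    import torch.nn as nn
    _orig_load = nn.Module.load_state_dict
    def _tolerant_load(self, state_dict, strict=True, **kwargs):
        # drop stale inv_freq buffers so they neither clash nor raise a key error
        filtered = {k: v for k, v in state_dict.items() if not k.endswith("inv_freq")}
        return _orig_load(self, filtered, strict=False, **kwargs)
    nn.Module.load_state_dict = _tolerant_load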
def build_conversation_input_ids(
    tokenizer: PreTrainedTokenizer,
    *,
@@ -86,17 +90,6 @@ def build_conversation_input_ids(
        'images': images,
    }
def image_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
    if do_recurse:
        for root, dirs, files in os.walk(image_dir):
            for file in files:
                if any(file.endswith(ext) for ext in SUPPORTED_EXT):
                    yield os.path.join(root, file)
    else:
        for file in os.listdir(image_dir):
            if any(file.endswith(ext) for ext in SUPPORTED_EXT):
                yield os.path.join(image_dir, file)
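# Note (illustrative): the helper above was removed because it now lives in
# data.gen_utils and is imported near the top of this file; usage is unchanged,
# e.g.:
#
#   for image_path in image_generator("input", do_recurse=True):
#       print(image_path)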
def get_gpu_memory_map():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
@@ -114,10 +107,25 @@ def save_params(args, gen_kwargs):
    with open(save_path, "w") as f:
        f.write(pretty_print)
def create_bnb_config(args):
    return BitsAndBytesConfig(
        bnb_4bit_compute_dtype="float32",
        bnb_4bit_quant_type="fp4",
        bnb_4bit_use_double_quant=False,
        llm_int8_enable_fp32_cpu_offload=False,
        llm_int8_has_fp16_weight=False,
        llm_int8_skip_modules=None,
        llm_int8_threshold=6.0,
        load_in_4bit=True,
        load_in_8bit=False,
        quant_method="bitsandbytes"
    )
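# Design note: these values spell out the defaults that the bare
# load_in_4bit=True flag implied (fp4 quantization, float32 compute dtype,
# no double quantization), so the quantized behavior is unchanged while
# newer transformers releases that expect an explicit quantization_config
# are satisfied.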
def main(args):
    prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
    bnb_config = create_bnb_config(args)
    tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
    model = AutoModelForCausalLM.from_pretrained(
        'THUDM/cogvlm-chat-hf',
@@ -125,7 +133,8 @@ def main(args):
        low_cpu_mem_usage=True,
        trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor
        #revision=... # no one is actually doing this
        load_in_4bit=not args.disable_4bit,
        #load_in_4bit=not args.disable_4bit,
        quantization_config=bnb_config,
    )
    do_sample = args.top_k is not None or args.top_p is not None or args.temp is not None
@@ -214,12 +223,24 @@ def main(args):
            'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
            'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
            'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)] for _ in range(args.num_beams)],
            'output_hidden_states': True,
            'return_dict': True
        }
        # print(f"** inputs type: {type(inputs)}") # dict
        # print(f"** inputs len: {len(inputs)}") # 4
        # print(f"** inputs keys: {inputs.keys()}") # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'images'])
        # print(f"** inputs['images'] shape: {inputs['images'].shape}") # list has no shape
        # print(f"** image_path: {image_path}")
        with torch.no_grad():
            #input_decoded = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
            #logging.debug(f"inputs decoded: {input_decoded}")
            #print(f"calling generate with input shapes: {inputs['input_ids'].shape}, {inputs['token_type_ids'].shape}, {inputs['attention_mask'].shape}, {inputs['images'][0][0].shape}")
            #calling generate with input shapes: torch.Size([1, 1352]), torch.Size([1, 1352]), torch.Size([1, 1352]), torch.Size([3, 490, 490])
            outputs = model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
            print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}")
            #print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}")
            len_inputs = inputs['input_ids'].shape[1]
            outputs_without_prompt = outputs[:, len_inputs:]
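            # Illustrative sketch of what typically follows: decode only the
            # newly generated tokens, dropping the echoed prompt.
            caption = tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)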

requirements.txt

@@ -1,7 +1,8 @@
torch==2.1.0
torchvision==0.16.0
transformers==4.35.0
transformers>=4.38.2
diffusers[torch]==0.21.4
peft>=0.9.0
pynvml==11.4.1
bitsandbytes==0.41.1
ftfy==6.1.1
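A quick sanity check (a minimal sketch; assumes the pins above are installed) that the upgraded stack imports together:

    import transformers, peft, bitsandbytes
    print(transformers.__version__, peft.__version__, bitsandbytes.__version__)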