From d2a7a25da75ef0b9aa5bf0aff9cbad6bd052fd0d Mon Sep 17 00:00:00 2001
From: Victor Hall
Date: Wed, 13 Mar 2024 18:38:30 -0400
Subject: [PATCH] update caption_cog.py, transformers, and add peft

---
 caption_cog.py   | 51 ++++++++++++++++++++++++++++++++++--------------
 requirements.txt |  3 ++-
 2 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/caption_cog.py b/caption_cog.py
index 26a4de2..d001584 100644
--- a/caption_cog.py
+++ b/caption_cog.py
@@ -21,7 +21,7 @@ import time
 import json
 import logging
 import re
-from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Literal
+from typing import TYPE_CHECKING, Generator, Optional, List, Tuple, Dict, Any
 
 import torch
 from torchvision import transforms
@@ -30,15 +30,19 @@ from PIL import Image
 import PIL.ImageOps as ImageOps
 from pynvml import *
 
-from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer
+from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer, BitsAndBytesConfig
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from colorama import Fore, Style
 
 from plugins.caption_plugins import load_prompt_alteration_plugin
+from utils.patch_cog import patch_cog
+from data.gen_utils import image_generator, SUPPORTED_EXT
 
-SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
 IMAGE_SIZE: int = 490
 PATCH_SIZE: int = 14
 
+patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions
+
 def build_conversation_input_ids(
     tokenizer: PreTrainedTokenizer,
     *,
@@ -86,17 +90,6 @@ def build_conversation_input_ids(
         'images': images,
     }
 
-def image_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
-    if do_recurse:
-        for root, dirs, files in os.walk(image_dir):
-            for file in files:
-                if any(file.endswith(ext) for ext in SUPPORTED_EXT):
-                    yield os.path.join(root, file)
-    else:
-        for file in os.listdir(image_dir):
-            if any(file.endswith(ext) for ext in SUPPORTED_EXT):
-                yield os.path.join(image_dir, file)
-
 def get_gpu_memory_map():
     nvmlInit()
     handle = nvmlDeviceGetHandleByIndex(0)
@@ -114,10 +107,25 @@ def save_params(args, gen_kwargs):
     with open(save_path, "w") as f:
         f.write(pretty_print)
 
+def create_bnb_config(args):
+    return BitsAndBytesConfig(
+        bnb_4bit_compute_dtype="float32",
+        bnb_4bit_quant_type="fp4",
+        bnb_4bit_use_double_quant=False,
+        llm_int8_enable_fp32_cpu_offload=False,
+        llm_int8_has_fp16_weight=False,
+        llm_int8_skip_modules=None,
+        llm_int8_threshold=6.0,
+        load_in_4bit=True,
+        load_in_8bit=False,
+        quant_method="bitsandbytes"
+    )
 
 def main(args):
     prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
 
+    bnb_config = create_bnb_config(args)
+
     tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
     model = AutoModelForCausalLM.from_pretrained(
         'THUDM/cogvlm-chat-hf',
@@ -125,7 +133,8 @@ def main(args):
         low_cpu_mem_usage=True,
         trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor
         #revision=... # no one is actually doing this
-        load_in_4bit=not args.disable_4bit,
+        #load_in_4bit=not args.disable_4bit,
+        quantization_config=bnb_config,
     )
 
     do_sample = args.top_k is not None or args.top_p is not None or args.temp is not None
@@ -214,12 +223,24 @@ def main(args):
         'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
         'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
         'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)] for _ in range(args.num_beams)],
+        'output_hidden_states': True,
+        'return_dict': True
     }
 
+    # print(f"** inputs type: {type(inputs)}") # dict
+    # print(f"** inputs len: {len(inputs)}") # 4
+    # print(f"** inputs keys: {inputs.keys()}") # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'images'])
+    # print(f"** inputs['images'] shape: {inputs['images'].shape}") # list has no shape
+    # print(f"** image_path: {image_path}")
+
     with torch.no_grad():
         #input_decoded = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
         #logging.debug(f"inputs decoded: {input_decoded}")
+        #print(f"calling generate with input shapes: {inputs['input_ids'].shape}, {inputs['token_type_ids'].shape}, {inputs['attention_mask'].shape}, {inputs['images'][0][0].shape}")
+        #calling generate with input shapes: torch.Size([1, 1352]), torch.Size([1, 1352]), torch.Size([1, 1352]), torch.Size([3, 490, 490])
         outputs = model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
+        print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}")
+        #print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}")
 
         len_inputs = inputs['input_ids'].shape[1]
         outputs_without_prompt = outputs[:, len_inputs:]
diff --git a/requirements.txt b/requirements.txt
index 0d1df20..078defe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,8 @@
 torch==2.1.0
 torchvision==0.16.0
-transformers==4.35.0
+transformers>=4.38.2
 diffusers[torch]==0.21.4
+peft>=0.9.0
 pynvml==11.4.1
 bitsandbytes==0.41.1
 ftfy==6.1.1
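
Note: the quantization_config change above follows the newer transformers convention, where an explicit BitsAndBytesConfig replaces passing load_in_4bit= directly to from_pretrained(). A minimal standalone sketch of that load path, reusing the model id and dtypes from the patch; this is illustrative only (assumes a CUDA device with enough VRAM, and argument choices mirror create_bnb_config() rather than prescribing them), not part of the commit:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Explicit 4-bit config; mirrors create_bnb_config() in the patch.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                 # quantize linear weights to 4-bit at load time
    bnb_4bit_quant_type="fp4",         # 4-bit float format; "nf4" is the common alternative
    bnb_4bit_compute_dtype="float32",  # dtype used for matmuls on dequantized weights
    bnb_4bit_use_double_quant=False,   # skip second-pass quantization of the quant constants
)

model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogvlm-chat-hf",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    quantization_config=bnb_config,    # replaces the deprecated load_in_4bit= kwarg
)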