Compare commits


2 Commits

Author SHA1 Message Date
Victor Hall 642a64d5dc generator 2024-03-15 00:20:37 -04:00
Victor Hall 17054c9d10 patch cog to fix issue with transformers 2024-03-15 00:18:55 -04:00
2 changed files with 42 additions and 0 deletions

data/gen_utils.py (new file, +15)

@@ -0,0 +1,15 @@
import os
from typing import Generator

# Image file extensions the generator will yield.
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]

def image_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
    """Lazily yield paths to supported image files under image_dir."""
    if do_recurse:
        # Walk the directory tree and yield every matching file.
        for root, dirs, files in os.walk(image_dir):
            for file in files:
                if any(file.endswith(ext) for ext in SUPPORTED_EXT):
                    yield os.path.join(root, file)
    else:
        # Only scan the top-level directory.
        for file in os.listdir(image_dir):
            if any(file.endswith(ext) for ext in SUPPORTED_EXT):
                yield os.path.join(image_dir, file)
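
A minimal usage sketch (not part of the commit), assuming the module is importable as data.gen_utils; the directory path is a hypothetical example:

from data.gen_utils import image_generator

# Stream image paths one at a time instead of building a full list in memory.
for image_path in image_generator("/path/to/images", do_recurse=True):
    print(image_path)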

utils/patch_cog.py (new file, +27)

@@ -0,0 +1,27 @@
from typing import Any, Dict

import torch
import bitsandbytes as bnb
from transformers.quantizers.quantizer_bnb_4bit import Bnb4BitHfQuantizer, get_module_from_name
from transformers.modeling_utils import PreTrainedModel

# CogVLM stores inv_freq in the state dictionary, but it is not in module._parameters,
# so it cannot be quantized. This was patched in transformers for other models here:
# https://github.com/huggingface/transformers/pull/28837/files
# but CogVLM is not part of transformers, so the quantizer is patched here instead.
def _patched_check_quantized_param(
    self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
) -> bool:
    # if "inv_freq" in param_name:  # detect failure case
    #     print("check_quantized_param", param_name)
    module, tensor_name = get_module_from_name(model, param_name)
    if tensor_name == "inv_freq":  # the fix: skip buffers that are not real parameters
        return False
    if isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):  # would raise KeyError for inv_freq without the fix
        return True
    elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
        return True
    else:
        return False

def patch_cog():
    # Swap in the patched check before the quantized model is loaded.
    Bnb4BitHfQuantizer.check_quantized_param = _patched_check_quantized_param
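
A usage sketch (not part of the commit) showing how the patch would be applied before loading a 4-bit quantized model; the model id and quantization settings below are illustrative assumptions, not taken from this repository:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from utils.patch_cog import patch_cog

# Apply the monkey patch before from_pretrained so the quantizer
# skips inv_freq buffers instead of raising a KeyError on them.
patch_cog()

model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogvlm-chat-hf",  # assumed model id, for illustration only
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    trust_remote_code=True,
)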