From 17054c9d1036e934a2027da7d5d8e49ccb70c80f Mon Sep 17 00:00:00 2001
From: Victor Hall
Date: Fri, 15 Mar 2024 00:18:55 -0400
Subject: [PATCH] patch cog to fix issue with transformers

---
 utils/patch_cog.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)
 create mode 100644 utils/patch_cog.py

diff --git a/utils/patch_cog.py b/utils/patch_cog.py
new file mode 100644
index 0000000..16147d5
--- /dev/null
+++ b/utils/patch_cog.py
@@ -0,0 +1,27 @@
+from typing import Any, Dict
+import torch
+import bitsandbytes as bnb
+from transformers.quantizers.quantizer_bnb_4bit import Bnb4BitHfQuantizer, get_module_from_name
+from transformers.modeling_utils import PreTrainedModel
+
+# CogVLM stores inv_freq in the state dict, but it is not in module._parameters, so it cannot be quantized.
+# Transformers fixed this for its own models in https://github.com/huggingface/transformers/pull/28837/files, but CogVLM is not part of transformers.
+def _patched_check_quantized_param(
+    self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
+) -> bool:
+
+    # if "inv_freq" in param_name:  # detect the failure case
+    #     print("check_quantized_param", param_name)
+
+    module, tensor_name = get_module_from_name(model, param_name)
+    if tensor_name == "inv_freq":  # the fix: never treat inv_freq buffers as quantizable params
+        return False
+    if isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):  # would raise KeyError for inv_freq
+        return True
+    elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
+        return True
+    else:
+        return False
+
+def patch_cog():
+    Bnb4BitHfQuantizer.check_quantized_param = _patched_check_quantized_param
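
Usage: the monkey-patch only takes effect if patch_cog() runs before the model is loaded. A minimal sketch, assuming the repo root is on sys.path and a 4-bit bitsandbytes load of a CogVLM checkpoint; the model id below is illustrative, not taken from the patch:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from utils.patch_cog import patch_cog

patch_cog()  # replace check_quantized_param before from_pretrained runs

model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogvlm-chat-hf",   # assumed checkpoint, substitute your own
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,   # CogVLM ships its modeling code on the Hub
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

Without the patch, the quantizer indexes module._parameters["inv_freq"] for the rotary-embedding buffers and raises a KeyError; with it, those tensors are simply reported as not quantizable and loading proceeds.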