Merge BNB 4bit. (#770)

# What does this PR do?


See #626 
<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: krzim <zimmerk4@live.com>
This commit is contained in:
Nicolas Patry 2023-08-03 23:00:59 +02:00 committed by GitHub
parent f91e9d282d
commit 16fadcec57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 166 additions and 67 deletions

View File

@ -252,6 +252,8 @@ You can also quantize the weights with bitsandbytes to reduce the VRAM requireme
make run-falcon-7b-instruct-quantize
```
4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
## Develop
```shell

View File

@ -22,6 +22,8 @@ mod env_runtime;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization {
Bitsandbytes,
BitsandbytesNF4,
BitsandbytesFP4,
Gptq,
}
@ -32,6 +34,12 @@ impl std::fmt::Display for Quantization {
Quantization::Bitsandbytes => {
write!(f, "bitsandbytes")
}
Quantization::BitsandbytesNF4 => {
write!(f, "bitsandbytes-nf4")
}
Quantization::BitsandbytesFP4 => {
write!(f, "bitsandbytes-fp4")
}
Quantization::Gptq => {
write!(f, "gptq")
}
@ -116,7 +124,8 @@ struct Args {
num_shard: Option<usize>,
/// Whether you want the model to be quantized. This will use `bitsandbytes` for
/// quantization on the fly, or `gptq`.
/// quantization on the fly, or `gptq`. 4bit quantization is available through
/// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
#[clap(long, env, value_enum)]
quantize: Option<Quantization>,

48
server/poetry.lock generated
View File

@ -192,13 +192,13 @@ files = [
[[package]]
name = "bitsandbytes"
version = "0.38.1"
description = "8-bit optimizers and matrix multiplication routines."
version = "0.40.2"
description = "k-bit optimizers and matrix multiplication routines."
optional = true
python-versions = "*"
files = [
{file = "bitsandbytes-0.38.1-py3-none-any.whl", hash = "sha256:5f532e7b1353eb7049ae831da2eb62ed8a1e0444116bd51b9e088a6e0bc7a34a"},
{file = "bitsandbytes-0.38.1.tar.gz", hash = "sha256:ba95a806b5065ea3263558e188f07eacb32ad691842932fb0d36a879883167ce"},
{file = "bitsandbytes-0.40.2-py3-none-any.whl", hash = "sha256:f0ae26f40c9230c9add9e7c70a10a5ced36fb6deff39906aec1ce4fd25e6ddc0"},
{file = "bitsandbytes-0.40.2.tar.gz", hash = "sha256:808ac966272c63bccb2be6d77365275a4c28f1fa348d33656e670de3cab40fc4"},
]
[[package]]
@ -1751,6 +1751,42 @@ tensorflow = ["tensorflow (>=2.11.0)"]
testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "numpy (>=1.21.6)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)"]
torch = ["torch (>=1.10)"]
[[package]]
name = "scipy"
version = "1.11.1"
description = "Fundamental algorithms for scientific computing in Python"
optional = false
python-versions = "<3.13,>=3.9"
files = [
{file = "scipy-1.11.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:aec8c62fbe52914f9cf28d846cf0401dd80ab80788bbab909434eb336ed07c04"},
{file = "scipy-1.11.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:3b9963798df1d8a52db41a6fc0e6fa65b1c60e85d73da27ae8bb754de4792481"},
{file = "scipy-1.11.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e8eb42db36526b130dfbc417609498a6192381abc1975b91e3eb238e0b41c1a"},
{file = "scipy-1.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:366a6a937110d80dca4f63b3f5b00cc89d36f678b2d124a01067b154e692bab1"},
{file = "scipy-1.11.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:08d957ca82d3535b3b9ba6c8ff355d78fe975271874e2af267cb5add5bd78625"},
{file = "scipy-1.11.1-cp310-cp310-win_amd64.whl", hash = "sha256:e866514bc2d660608447b6ba95c8900d591f2865c07cca0aa4f7ff3c4ca70f30"},
{file = "scipy-1.11.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ba94eeef3c9caa4cea7b402a35bb02a5714ee1ee77eb98aca1eed4543beb0f4c"},
{file = "scipy-1.11.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:512fdc18c65f76dadaca139348e525646d440220d8d05f6d21965b8d4466bccd"},
{file = "scipy-1.11.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cce154372f0ebe88556ed06d7b196e9c2e0c13080ecb58d0f35062dc7cc28b47"},
{file = "scipy-1.11.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4bb943010203465ac81efa392e4645265077b4d9e99b66cf3ed33ae12254173"},
{file = "scipy-1.11.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:249cfa465c379c9bb2c20123001e151ff5e29b351cbb7f9c91587260602c58d0"},
{file = "scipy-1.11.1-cp311-cp311-win_amd64.whl", hash = "sha256:ffb28e3fa31b9c376d0fb1f74c1f13911c8c154a760312fbee87a21eb21efe31"},
{file = "scipy-1.11.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:39154437654260a52871dfde852adf1b93b1d1bc5dc0ffa70068f16ec0be2624"},
{file = "scipy-1.11.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b588311875c58d1acd4ef17c983b9f1ab5391755a47c3d70b6bd503a45bfaf71"},
{file = "scipy-1.11.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d51565560565a0307ed06fa0ec4c6f21ff094947d4844d6068ed04400c72d0c3"},
{file = "scipy-1.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b41a0f322b4eb51b078cb3441e950ad661ede490c3aca66edef66f4b37ab1877"},
{file = "scipy-1.11.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:396fae3f8c12ad14c5f3eb40499fd06a6fef8393a6baa352a652ecd51e74e029"},
{file = "scipy-1.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:be8c962a821957fdde8c4044efdab7a140c13294997a407eaee777acf63cbf0c"},
{file = "scipy-1.11.1.tar.gz", hash = "sha256:fb5b492fa035334fd249f0973cc79ecad8b09c604b42a127a677b45a9a3d4289"},
]
[package.dependencies]
numpy = ">=1.21.6,<1.28.0"
[package.extras]
dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"]
test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
[[package]]
name = "sentencepiece"
version = "0.1.99"
@ -2425,5 +2461,5 @@ quantize = ["accelerate", "datasets", "texttable"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "93fd0873b3e16c10b67a84216a84f5eb2f5067cb3ff8cb912446cc6a7fa9c030"
python-versions = ">=3.9,<3.13"
content-hash = "2abb80833b678452cfc73464fc5b2e48d74b2672bd987240041a33c724a74000"

View File

@ -8,7 +8,7 @@ authors = ["Olivier Dehaene <olivier@huggingface.co>"]
text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = "^3.9"
python = ">=3.9,<3.13"
protobuf = "^4.21.7"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
accelerate = { version = "^0.19.0", optional = true }
bitsandbytes = { version = "^0.38.1", optional = true }
bitsandbytes = { version = "^0.40.0", optional = true }
safetensors = "0.3.1"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
@ -32,6 +32,7 @@ texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
peft = "^0.4.0"
torch = {version = "^2.0.1+cu118", source = "pytorch-gpu-src"}
scipy = "^1.11.1"
[tool.poetry.extras]
accelerate = ["accelerate"]

View File

@ -1,58 +1,59 @@
--extra-index-url https://download.pytorch.org/whl/cu118
accelerate==0.19.0 ; python_version >= "3.9" and python_version < "4.0"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "4.0"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "4.0"
click==8.1.6 ; python_version >= "3.9" and python_version < "4.0"
cmake==3.27.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "4.0"
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio-reflection==1.56.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio-status==1.56.2 ; python_version >= "3.9" and python_version < "4.0"
grpcio==1.56.2 ; python_version >= "3.9" and python_version < "4.0"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
lit==16.0.6 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "4.0"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
networkx==3.1 ; python_version >= "3.9" and python_version < "4.0"
numpy==1.25.2 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
peft==0.4.0 ; python_version >= "3.9" and python_version < "4.0"
protobuf==4.23.4 ; python_version >= "3.9" and python_version < "4.0"
psutil==5.9.5 ; python_version >= "3.9" and python_version < "4.0"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
sympy==1.12 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
torch==2.0.1+cu118 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "4.0"
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "4.0"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
accelerate==0.19.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.6 ; python_version >= "3.9" and python_version < "3.13"
cmake==3.27.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.12.2 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
lit==16.0.6 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.25.2 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
peft==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.23.4 ; python_version >= "3.9" and python_version < "3.13"
psutil==5.9.5 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.6.3 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==68.0.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
torch==2.0.1+cu118 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.29.2 ; python_version >= "3.9" and python_version < "3.13"
triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -14,6 +14,8 @@ app = typer.Typer()
class Quantization(str, Enum):
bitsandbytes = "bitsandbytes"
bitsandbytes_nf4 = "bitsandbytes-nf4"
bitsandbytes_fp4 = "bitsandbytes-fp4"
gptq = "gptq"

View File

@ -255,7 +255,10 @@ def get_model(
raise ValueError(
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
raise ValueError(
"4bit quantization is not supported for AutoModel"
)
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,

View File

@ -9,7 +9,7 @@ from typing import List
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
from bitsandbytes.nn import Int8Params, Params4bit
except ImportError:
HAS_BITS_AND_BYTES = False
@ -140,6 +140,39 @@ class Linear8bitLt(nn.Module):
return out
class Linear4bit(nn.Module):
def __init__(self, weight, bias, quant_type):
super().__init__()
self.weight = Params4bit(
weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
)
self.compute_dtype = None
self.weight.cuda(weight.device)
self.bias = bias
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
if getattr(self.weight, "quant_state", None) is None:
print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
)
inp_dtype = x.dtype
if self.compute_dtype is not None:
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
out = bnb.matmul_4bit(
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
)
out = out.to(inp_dtype)
return out
def get_linear(weight, bias, quantize):
if quantize is None:
linear = FastLinear(weight, bias)
@ -152,6 +185,18 @@ def get_linear(weight, bias, quantize):
)
if bias is not None:
linear.bias = nn.Parameter(bias)
elif quantize == "bitsandbytes-fp4":
linear = Linear4bit(
weight,
bias,
quant_type="fp4",
)
elif quantize == "bitsandbytes-nf4":
linear = Linear4bit(
weight,
bias,
quant_type="nf4",
)
elif quantize == "gptq":
try:
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight