Merge BNB 4bit. (#770)
# What does this PR do? See #626 <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: krzim <zimmerk4@live.com>
This commit is contained in:
parent
f91e9d282d
commit
16fadcec57
|
@ -252,6 +252,8 @@ You can also quantize the weights with bitsandbytes to reduce the VRAM requireme
|
|||
make run-falcon-7b-instruct-quantize
|
||||
```
|
||||
|
||||
4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
|
||||
|
||||
## Develop
|
||||
|
||||
```shell
|
||||
|
|
|
@ -22,6 +22,8 @@ mod env_runtime;
|
|||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Quantization {
|
||||
Bitsandbytes,
|
||||
BitsandbytesNF4,
|
||||
BitsandbytesFP4,
|
||||
Gptq,
|
||||
}
|
||||
|
||||
|
@ -32,6 +34,12 @@ impl std::fmt::Display for Quantization {
|
|||
Quantization::Bitsandbytes => {
|
||||
write!(f, "bitsandbytes")
|
||||
}
|
||||
Quantization::BitsandbytesNF4 => {
|
||||
write!(f, "bitsandbytes-nf4")
|
||||
}
|
||||
Quantization::BitsandbytesFP4 => {
|
||||
write!(f, "bitsandbytes-fp4")
|
||||
}
|
||||
Quantization::Gptq => {
|
||||
write!(f, "gptq")
|
||||
}
|
||||
|
@ -116,7 +124,8 @@ struct Args {
|
|||
num_shard: Option<usize>,
|
||||
|
||||
/// Whether you want the model to be quantized. This will use `bitsandbytes` for
|
||||
/// quantization on the fly, or `gptq`.
|
||||
/// quantization on the fly, or `gptq`. 4bit quantization is available through
|
||||
/// `bitsandbytes` by providing the `bitsandbytes-fp4` or `bitsandbytes-nf4` options.
|
||||
#[clap(long, env, value_enum)]
|
||||
quantize: Option<Quantization>,
|
||||
|
||||
|
|
|
@ -192,13 +192,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "bitsandbytes"
|
||||
version = "0.38.1"
|
||||
description = "8-bit optimizers and matrix multiplication routines."
|
||||
version = "0.40.2"
|
||||
description = "k-bit optimizers and matrix multiplication routines."
|
||||
optional = true
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "bitsandbytes-0.38.1-py3-none-any.whl", hash = "sha256:5f532e7b1353eb7049ae831da2eb62ed8a1e0444116bd51b9e088a6e0bc7a34a"},
|
||||
{file = "bitsandbytes-0.38.1.tar.gz", hash = "sha256:ba95a806b5065ea3263558e188f07eacb32ad691842932fb0d36a879883167ce"},
|
||||
{file = "bitsandbytes-0.40.2-py3-none-any.whl", hash = "sha256:f0ae26f40c9230c9add9e7c70a10a5ced36fb6deff39906aec1ce4fd25e6ddc0"},
|
||||
{file = "bitsandbytes-0.40.2.tar.gz", hash = "sha256:808ac966272c63bccb2be6d77365275a4c28f1fa348d33656e670de3cab40fc4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1751,6 +1751,42 @@ tensorflow = ["tensorflow (>=2.11.0)"]
|
|||
testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "numpy (>=1.21.6)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "setuptools-rust (>=1.5.2)"]
|
||||
torch = ["torch (>=1.10)"]
|
||||
|
||||
[[package]]
|
||||
name = "scipy"
|
||||
version = "1.11.1"
|
||||
description = "Fundamental algorithms for scientific computing in Python"
|
||||
optional = false
|
||||
python-versions = "<3.13,>=3.9"
|
||||
files = [
|
||||
{file = "scipy-1.11.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:aec8c62fbe52914f9cf28d846cf0401dd80ab80788bbab909434eb336ed07c04"},
|
||||
{file = "scipy-1.11.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:3b9963798df1d8a52db41a6fc0e6fa65b1c60e85d73da27ae8bb754de4792481"},
|
||||
{file = "scipy-1.11.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e8eb42db36526b130dfbc417609498a6192381abc1975b91e3eb238e0b41c1a"},
|
||||
{file = "scipy-1.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:366a6a937110d80dca4f63b3f5b00cc89d36f678b2d124a01067b154e692bab1"},
|
||||
{file = "scipy-1.11.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:08d957ca82d3535b3b9ba6c8ff355d78fe975271874e2af267cb5add5bd78625"},
|
||||
{file = "scipy-1.11.1-cp310-cp310-win_amd64.whl", hash = "sha256:e866514bc2d660608447b6ba95c8900d591f2865c07cca0aa4f7ff3c4ca70f30"},
|
||||
{file = "scipy-1.11.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ba94eeef3c9caa4cea7b402a35bb02a5714ee1ee77eb98aca1eed4543beb0f4c"},
|
||||
{file = "scipy-1.11.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:512fdc18c65f76dadaca139348e525646d440220d8d05f6d21965b8d4466bccd"},
|
||||
{file = "scipy-1.11.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cce154372f0ebe88556ed06d7b196e9c2e0c13080ecb58d0f35062dc7cc28b47"},
|
||||
{file = "scipy-1.11.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4bb943010203465ac81efa392e4645265077b4d9e99b66cf3ed33ae12254173"},
|
||||
{file = "scipy-1.11.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:249cfa465c379c9bb2c20123001e151ff5e29b351cbb7f9c91587260602c58d0"},
|
||||
{file = "scipy-1.11.1-cp311-cp311-win_amd64.whl", hash = "sha256:ffb28e3fa31b9c376d0fb1f74c1f13911c8c154a760312fbee87a21eb21efe31"},
|
||||
{file = "scipy-1.11.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:39154437654260a52871dfde852adf1b93b1d1bc5dc0ffa70068f16ec0be2624"},
|
||||
{file = "scipy-1.11.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b588311875c58d1acd4ef17c983b9f1ab5391755a47c3d70b6bd503a45bfaf71"},
|
||||
{file = "scipy-1.11.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d51565560565a0307ed06fa0ec4c6f21ff094947d4844d6068ed04400c72d0c3"},
|
||||
{file = "scipy-1.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b41a0f322b4eb51b078cb3441e950ad661ede490c3aca66edef66f4b37ab1877"},
|
||||
{file = "scipy-1.11.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:396fae3f8c12ad14c5f3eb40499fd06a6fef8393a6baa352a652ecd51e74e029"},
|
||||
{file = "scipy-1.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:be8c962a821957fdde8c4044efdab7a140c13294997a407eaee777acf63cbf0c"},
|
||||
{file = "scipy-1.11.1.tar.gz", hash = "sha256:fb5b492fa035334fd249f0973cc79ecad8b09c604b42a127a677b45a9a3d4289"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.21.6,<1.28.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
|
||||
doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"]
|
||||
test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
|
||||
|
||||
[[package]]
|
||||
name = "sentencepiece"
|
||||
version = "0.1.99"
|
||||
|
@ -2425,5 +2461,5 @@ quantize = ["accelerate", "datasets", "texttable"]
|
|||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "93fd0873b3e16c10b67a84216a84f5eb2f5067cb3ff8cb912446cc6a7fa9c030"
|
||||
python-versions = ">=3.9,<3.13"
|
||||
content-hash = "2abb80833b678452cfc73464fc5b2e48d74b2672bd987240041a33c724a74000"
|
||||
|
|
|
@ -8,7 +8,7 @@ authors = ["Olivier Dehaene <olivier@huggingface.co>"]
|
|||
text-generation-server = 'text_generation_server.cli:app'
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.9"
|
||||
python = ">=3.9,<3.13"
|
||||
protobuf = "^4.21.7"
|
||||
grpcio = "^1.51.1"
|
||||
grpcio-status = "^1.51.1"
|
||||
|
@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
|
|||
grpc-interceptor = "^0.15.0"
|
||||
typer = "^0.6.1"
|
||||
accelerate = { version = "^0.19.0", optional = true }
|
||||
bitsandbytes = { version = "^0.38.1", optional = true }
|
||||
bitsandbytes = { version = "^0.40.0", optional = true }
|
||||
safetensors = "0.3.1"
|
||||
loguru = "^0.6.0"
|
||||
opentelemetry-api = "^1.15.0"
|
||||
|
@ -32,6 +32,7 @@ texttable = { version = "^1.6.7", optional = true }
|
|||
datasets = { version = "^2.14.0", optional = true }
|
||||
peft = "^0.4.0"
|
||||
torch = {version = "^2.0.1+cu118", source = "pytorch-gpu-src"}
|
||||
scipy = "^1.11.1"
|
||||
|
||||
[tool.poetry.extras]
|
||||
accelerate = ["accelerate"]
|
||||
|
|
|
@ -1,58 +1,59 @@
|
|||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
accelerate==0.19.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "4.0"
|
||||
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
click==8.1.6 ; python_version >= "3.9" and python_version < "4.0"
|
||||
cmake==3.27.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "4.0"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
filelock==3.12.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio-reflection==1.56.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio-status==1.56.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio==1.56.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
|
||||
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
lit==16.0.6 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "4.0"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
networkx==3.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
numpy==1.25.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
packaging==23.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
peft==0.4.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
protobuf==4.23.4 ; python_version >= "3.9" and python_version < "4.0"
|
||||
psutil==5.9.5 ; python_version >= "3.9" and python_version < "4.0"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
regex==2023.6.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
|
||||
setuptools==68.0.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
sympy==1.12 ; python_version >= "3.9" and python_version < "4.0"
|
||||
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
torch==2.0.1+cu118 ; python_version >= "3.9" and python_version < "4.0"
|
||||
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "4.0"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "4.0"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32"
|
||||
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
accelerate==0.19.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
|
||||
charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
click==8.1.6 ; python_version >= "3.9" and python_version < "3.13"
|
||||
cmake==3.27.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
|
||||
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
|
||||
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
|
||||
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
filelock==3.12.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
fsspec==2023.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpc-interceptor==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-reflection==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio-status==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
grpcio==1.56.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
lit==16.0.6 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
numpy==1.25.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
peft==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
protobuf==4.23.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
psutil==5.9.5 ; python_version >= "3.9" and python_version < "3.13"
|
||||
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
regex==2023.6.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
safetensors==0.3.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
scipy==1.11.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
|
||||
setuptools==68.0.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
|
||||
torch==2.0.1+cu118 ; python_version >= "3.9" and python_version < "3.13"
|
||||
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
transformers==4.29.2 ; python_version >= "3.9" and python_version < "3.13"
|
||||
triton==2.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
|
||||
urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
|
||||
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
|
||||
wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
|
||||
|
|
|
@ -14,6 +14,8 @@ app = typer.Typer()
|
|||
|
||||
class Quantization(str, Enum):
|
||||
bitsandbytes = "bitsandbytes"
|
||||
bitsandbytes_nf4 = "bitsandbytes-nf4"
|
||||
bitsandbytes_fp4 = "bitsandbytes-fp4"
|
||||
gptq = "gptq"
|
||||
|
||||
|
||||
|
|
|
@ -255,7 +255,10 @@ def get_model(
|
|||
raise ValueError(
|
||||
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
|
||||
raise ValueError(
|
||||
"4bit quantization is not supported for AutoModel"
|
||||
)
|
||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
return CausalLM(
|
||||
model_id,
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import List
|
|||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
from bitsandbytes.nn import Int8Params, Params4bit
|
||||
|
||||
except ImportError:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
@ -140,6 +140,39 @@ class Linear8bitLt(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
class Linear4bit(nn.Module):
|
||||
def __init__(self, weight, bias, quant_type):
|
||||
super().__init__()
|
||||
self.weight = Params4bit(
|
||||
weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
|
||||
)
|
||||
self.compute_dtype = None
|
||||
self.weight.cuda(weight.device)
|
||||
self.bias = bias
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
if getattr(self.weight, "quant_state", None) is None:
|
||||
print(
|
||||
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
|
||||
)
|
||||
inp_dtype = x.dtype
|
||||
if self.compute_dtype is not None:
|
||||
x = x.to(self.compute_dtype)
|
||||
|
||||
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
|
||||
out = bnb.matmul_4bit(
|
||||
x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
|
||||
)
|
||||
|
||||
out = out.to(inp_dtype)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def get_linear(weight, bias, quantize):
|
||||
if quantize is None:
|
||||
linear = FastLinear(weight, bias)
|
||||
|
@ -152,6 +185,18 @@ def get_linear(weight, bias, quantize):
|
|||
)
|
||||
if bias is not None:
|
||||
linear.bias = nn.Parameter(bias)
|
||||
elif quantize == "bitsandbytes-fp4":
|
||||
linear = Linear4bit(
|
||||
weight,
|
||||
bias,
|
||||
quant_type="fp4",
|
||||
)
|
||||
elif quantize == "bitsandbytes-nf4":
|
||||
linear = Linear4bit(
|
||||
weight,
|
||||
bias,
|
||||
quant_type="nf4",
|
||||
)
|
||||
elif quantize == "gptq":
|
||||
try:
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
|
||||
|
|
Loading…
Reference in New Issue