From 3c9df21ff8f0627988728388e95f097bb1f89217 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Mon, 18 Nov 2024 17:20:31 +0100 Subject: [PATCH] Add support for compressed-tensors w8a8 int checkpoints (#2745) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add support for compressed-tensors w8a8 int checkpoints This change adds a loader for w8a8 int checkpoints. One large benefit of int8 support is that the corresponding cutlass matmul kernels also work on compute capability 7.5. Evaluation on neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8: | Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr| |---------------|------:|----------------|-----:|-----------------------|---|-----:|---|------| |gsm8k_cot_llama| 3|flexible-extract| 8|exact_match |↑ |0.8431|± |0.0100| | | |strict-match | 8|exact_match |↑ |0.8393|± |0.0101| |ifeval | 4|none | 0|inst_level_loose_acc |↑ |0.8597|± | N/A| | | |none | 0|inst_level_strict_acc |↑ |0.8201|± | N/A| | | |none | 0|prompt_level_loose_acc |↑ |0.7967|± |0.0173| | | |none | 0|prompt_level_strict_acc|↑ |0.7468|± |0.0187| Which is the same ballpark as vLLM. As usual, lots of thanks to Neural Magic/vLLM for the kernels. * Always use dynamic input quantization for w8a8 int It's far less flaky and gives better output. 
* Use marlin-kernels 0.3.5 * Fix a typo Co-authored-by: drbh * Small fixes --------- Co-authored-by: drbh --- flake.lock | 7 +- flake.nix | 2 +- .../test_compressed_tensors_w8a8_int.json | 104 +++++ ...ompressed_tensors_w8a8_int_all_params.json | 99 +++++ ...test_compressed_tensors_w8a8_int_load.json | 418 ++++++++++++++++++ ...essed_tensors_w8a8_int_dynamic_weight.json | 99 +++++ ...rs_w8a8_int_dynamic_weight_all_params.json | 94 ++++ ..._tensors_w8a8_int_dynamic_weight_load.json | 398 +++++++++++++++++ .../test_compressed_tensors_w8a8_int.py | 90 ++++ ...pressed_tensors_w8a8_int_dynamic_weight.py | 92 ++++ server/poetry.lock | 26 +- server/pyproject.toml | 8 +- .../layers/compressed_tensors/loader.py | 12 + .../layers/compressed_tensors/w8a8_int.py | 241 ++++++++++ .../text_generation_server/utils/weights.py | 5 +- 15 files changed, 1673 insertions(+), 22 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int.json create mode 100644 integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_load.json create mode 100644 integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight.json create mode 100644 integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json create mode 100644 integration-tests/models/test_compressed_tensors_w8a8_int.py create mode 100644 integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py create mode 100644 
server/text_generation_server/layers/compressed_tensors/w8a8_int.py diff --git a/flake.lock b/flake.lock index 6d2ff5dc..14860461 100644 --- a/flake.lock +++ b/flake.lock @@ -978,15 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1731674227, - "narHash": "sha256-k/ur37KSc+RXcwwz0tgxeamz6wQ5rsOe5hMepzIdD2s=", + "lastModified": 1731923801, + "narHash": "sha256-SVtXtTGgnKjwPwMLe030l/DVhcm1vH4fXM7tUAPYOZc=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "407b9e22a0b7121bf6e171d67ce0144e3f3e39bf", + "rev": "b87d4b5bede0ffed7da50e9a5246b133c7d618dc", "type": "github" }, "original": { "owner": "huggingface", + "ref": "marlin-kernels-0.3.5", "repo": "text-generation-inference-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index f26a983e..cdde7a4c 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.5"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int.json new file mode 100644 index 00000000..1f7e0425 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int.json @@ -0,0 +1,104 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 3923, + "logprob": -6.3867188, + "text": "What" + }, + { + "id": 374, + "logprob": -1.1318359, + "text": " is" + }, + { + "id": 5655, + "logprob": 
-9.6875, + "text": " deep" + }, + { + "id": 6975, + "logprob": -1.3007812, + "text": " learning" + }, + { + "id": 30, + "logprob": -2.4902344, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 323, + "logprob": -1.1171875, + "special": false, + "text": " and" + }, + { + "id": 1268, + "logprob": -0.9477539, + "special": false, + "text": " how" + }, + { + "id": 1587, + "logprob": -0.51464844, + "special": false, + "text": " does" + }, + { + "id": 433, + "logprob": -0.043182373, + "special": false, + "text": " it" + }, + { + "id": 1782, + "logprob": -1.0810547, + "special": false, + "text": " differ" + }, + { + "id": 505, + "logprob": -0.005054474, + "special": false, + "text": " from" + }, + { + "id": 8776, + "logprob": -0.47485352, + "special": false, + "text": " traditional" + }, + { + "id": 5780, + "logprob": -0.15112305, + "special": false, + "text": " machine" + }, + { + "id": 6975, + "logprob": -0.0011291504, + "special": false, + "text": " learning" + }, + { + "id": 5380, + "logprob": -0.31323242, + "special": false, + "text": "?\n" + } + ], + "top_tokens": null + }, + "generated_text": " and how does it differ from traditional machine learning?\n" +} diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json new file mode 100644 index 00000000..c1a789ef --- /dev/null +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 3923, + "logprob": -6.3867188, + "text": "What" + }, + { + "id": 374, + "logprob": -1.1318359, + "text": " is" + }, + { + "id": 5655, + 
"logprob": -9.6875, + "text": " deep" + }, + { + "id": 6975, + "logprob": -1.3007812, + "text": " learning" + } + ], + "seed": 0, + "tokens": [ + { + "id": 5380, + "logprob": 0.0, + "special": false, + "text": "?\n" + }, + { + "id": 34564, + "logprob": 0.0, + "special": false, + "text": "Deep" + }, + { + "id": 6975, + "logprob": 0.0, + "special": false, + "text": " learning" + }, + { + "id": 11, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 1101, + "logprob": -1.0947266, + "special": false, + "text": " also" + }, + { + "id": 3967, + "logprob": 0.0, + "special": false, + "text": " known" + }, + { + "id": 439, + "logprob": 0.0, + "special": false, + "text": " as" + }, + { + "id": 30828, + "logprob": 0.0, + "special": false, + "text": " neural" + }, + { + "id": 4009, + "logprob": -0.15563965, + "special": false, + "text": " network" + }, + { + "id": 477, + "logprob": -1.4003906, + "special": false, + "text": " or" + } + ], + "top_tokens": null + }, + "generated_text": "What is deep learning?\nDeep learning, also known as neural network or" +} diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_load.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_load.json new file mode 100644 index 00000000..a177ee9a --- /dev/null +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_load.json @@ -0,0 +1,418 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 3923, + "logprob": -6.3867188, + "text": "What" + }, + { + "id": 374, + "logprob": -1.1318359, + "text": " is" + }, + { + "id": 5655, + "logprob": -9.6875, + "text": " deep" + }, + { + "id": 6975, + "logprob": -1.3007812, + "text": " learning" + }, + { + "id": 30, + 
"logprob": -2.4902344, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 323, + "logprob": -1.1171875, + "special": false, + "text": " and" + }, + { + "id": 1268, + "logprob": -0.9477539, + "special": false, + "text": " how" + }, + { + "id": 1587, + "logprob": -0.51464844, + "special": false, + "text": " does" + }, + { + "id": 433, + "logprob": -0.043182373, + "special": false, + "text": " it" + }, + { + "id": 1782, + "logprob": -1.0810547, + "special": false, + "text": " differ" + }, + { + "id": 505, + "logprob": -0.005054474, + "special": false, + "text": " from" + }, + { + "id": 8776, + "logprob": -0.47485352, + "special": false, + "text": " traditional" + }, + { + "id": 5780, + "logprob": -0.15112305, + "special": false, + "text": " machine" + }, + { + "id": 6975, + "logprob": -0.0011291504, + "special": false, + "text": " learning" + }, + { + "id": 5380, + "logprob": -0.3173828, + "special": false, + "text": "?\n" + } + ], + "top_tokens": null + }, + "generated_text": " and how does it differ from traditional machine learning?\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 3923, + "logprob": -6.3867188, + "text": "What" + }, + { + "id": 374, + "logprob": -1.1318359, + "text": " is" + }, + { + "id": 5655, + "logprob": -9.6875, + "text": " deep" + }, + { + "id": 6975, + "logprob": -1.3007812, + "text": " learning" + }, + { + "id": 30, + "logprob": -2.4902344, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 323, + "logprob": -1.1171875, + "special": false, + "text": " and" + }, + { + "id": 1268, + "logprob": -0.9477539, + "special": false, + "text": " how" + }, + { + "id": 1587, + "logprob": -0.51464844, + "special": false, + "text": " does" + }, + { + "id": 433, + "logprob": -0.043182373, + "special": false, + "text": " it" + }, + { + "id": 1782, + "logprob": -1.0810547, + "special": false, + "text": " differ" + }, + { + "id": 505, + "logprob": -0.005054474, + "special": false, + "text": " from" + }, + { + "id": 8776, + "logprob": -0.47485352, + "special": false, + "text": " traditional" + }, + { + "id": 5780, + "logprob": -0.15112305, + "special": false, + "text": " machine" + }, + { + "id": 6975, + "logprob": -0.0011291504, + "special": false, + "text": " learning" + }, + { + "id": 5380, + "logprob": -0.3173828, + "special": false, + "text": "?\n" + } + ], + "top_tokens": null + }, + "generated_text": " and how does it differ from traditional machine learning?\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 3923, + "logprob": -6.3867188, + "text": "What" + }, + { + "id": 374, + "logprob": -1.1318359, + "text": " is" + }, + { + "id": 5655, + "logprob": -9.6875, + "text": " deep" + }, + { + "id": 6975, + "logprob": -1.3007812, + "text": " learning" + }, + { + "id": 30, + "logprob": -2.4902344, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 323, + "logprob": -1.1171875, + "special": false, + "text": " and" + }, + { + "id": 1268, + "logprob": -0.9477539, + "special": false, + "text": " how" + }, + { + "id": 1587, + "logprob": -0.51464844, + "special": false, + "text": " does" + }, + { + "id": 433, + "logprob": -0.043182373, + "special": false, + "text": " it" + }, + { + "id": 1782, + "logprob": -1.0810547, + "special": false, + "text": " differ" + }, + { + "id": 505, + "logprob": -0.005054474, + "special": false, + "text": " from" + }, + { + "id": 8776, + "logprob": -0.47485352, + "special": false, + "text": " traditional" + }, + { + "id": 5780, + "logprob": -0.15112305, + "special": false, + "text": " machine" + }, + { + "id": 6975, + "logprob": -0.0011291504, + "special": false, + "text": " learning" + }, + { + "id": 5380, + "logprob": -0.3173828, + "special": false, + "text": "?\n" + } + ], + "top_tokens": null + }, + "generated_text": " and how does it differ from traditional machine learning?\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 3923, + "logprob": -6.3867188, + "text": "What" + }, + { + "id": 374, + "logprob": -1.1318359, + "text": " is" + }, + { + "id": 5655, + "logprob": -9.6875, + "text": " deep" + }, + { + "id": 6975, + "logprob": -1.3007812, + "text": " learning" + }, + { + "id": 30, + "logprob": -2.4902344, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 323, + "logprob": -1.1171875, + "special": false, + "text": " and" + }, + { + "id": 1268, + "logprob": -0.9477539, + "special": false, + "text": " how" + }, + { + "id": 1587, + "logprob": -0.51464844, + "special": false, + "text": " does" + }, + { + "id": 433, + "logprob": -0.043182373, + "special": false, + "text": " it" + }, + { + "id": 1782, + "logprob": -1.0810547, + "special": false, + "text": " differ" + }, + { + "id": 505, + "logprob": -0.005054474, + "special": false, + "text": " from" + }, + { + "id": 8776, + "logprob": -0.47485352, + "special": false, + "text": " traditional" + }, + { + "id": 5780, + "logprob": -0.15112305, + "special": false, + "text": " machine" + }, + { + "id": 6975, + "logprob": -0.0011291504, + "special": false, + "text": " learning" + }, + { + "id": 5380, + "logprob": -0.3173828, + "special": false, + "text": "?\n" + } + ], + "top_tokens": null + }, + "generated_text": " and how does it differ from traditional machine learning?\n" + } +] diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight.json new file mode 100644 index 00000000..1fb53c25 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 3838, + "logprob": null, + "text": "What" + }, + { + "id": 374, + "logprob": -8.59375, + "text": " is" + }, + { + "id": 5538, + "logprob": -10.921875, + "text": " deep" + }, + { + "id": 6832, + "logprob": -0.56347656, + "text": " learning" + }, + { + "id": 30, + "logprob": -1.5, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 18183, + "logprob": -1.6669922, + "special": false, + "text": " Deep" + }, + { + "id": 6832, + "logprob": -0.08959961, + "special": false, + "text": " learning" + }, + { + "id": 374, + "logprob": -0.14685059, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.125, + "special": false, + "text": " a" + }, + { + "id": 25993, + "logprob": -0.81640625, + "special": false, + "text": " subset" + }, + { + "id": 315, + "logprob": -0.0013418198, + "special": false, + "text": " of" + }, + { + "id": 5662, + "logprob": -0.16027832, + "special": false, + "text": " machine" + }, + { + "id": 6832, + "logprob": -0.0016393661, + "special": false, + "text": " learning" + }, + { + "id": 429, + "logprob": -0.4477539, + "special": false, + "text": " that" + }, + { + "id": 5711, + "logprob": -1.2802734, + "special": false, + "text": " uses" + } + ], + "top_tokens": null + }, + "generated_text": " Deep learning is a subset of machine learning that uses" +} diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json new file mode 100644 index 00000000..ca665b83 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json @@ -0,0 +1,94 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 3838, + "logprob": null, + "text": "What" + }, + { + "id": 374, + "logprob": -8.59375, + "text": " is" + }, + { + "id": 5538, + "logprob": -10.921875, + "text": " deep" + }, + { + "id": 6832, + "logprob": -0.56347656, + "text": " learning" + } + ], + "seed": 0, + "tokens": [ + { + "id": 1939, + 
"logprob": -2.2675781, + "special": false, + "text": "?\n\n" + }, + { + "id": 33464, + "logprob": 0.0, + "special": false, + "text": "Deep" + }, + { + "id": 20909, + "logprob": -0.37695312, + "special": false, + "text": " Learning" + }, + { + "id": 4102, + "logprob": -1.9316406, + "special": false, + "text": " " + }, + { + "id": 285, + "logprob": 0.0, + "special": false, + "text": "is" + }, + { + "id": 458, + "logprob": -0.80859375, + "special": false, + "text": " an" + }, + { + "id": 3082, + "logprob": -1.4541016, + "special": false, + "text": " area" + }, + { + "id": 315, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 20443, + "logprob": -0.5136719, + "special": false, + "text": " artificial" + }, + { + "id": 11229, + "logprob": 0.0, + "special": false, + "text": " intelligence" + } + ], + "top_tokens": null + }, + "generated_text": "What is deep learning?\n\nDeep Learning is an area of artificial intelligence" +} diff --git a/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json new file mode 100644 index 00000000..3ebeabf2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json @@ -0,0 +1,398 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 3838, + "logprob": null, + "text": "What" + }, + { + "id": 374, + "logprob": -8.59375, + "text": " is" + }, + { + "id": 5538, + "logprob": -10.921875, + "text": " deep" + }, + { + "id": 6832, + "logprob": -0.56347656, + "text": " learning" + }, + { + "id": 30, + "logprob": -1.5, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 18183, + "logprob": -1.6669922, + "special": false, + "text": " Deep" + }, + { + "id": 6832, + "logprob": -0.08959961, + "special": false, + "text": " learning" + }, + { + "id": 374, + "logprob": -0.14685059, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.125, + "special": false, + "text": " a" + }, + { + "id": 25993, + "logprob": -0.81640625, + "special": false, + "text": " subset" + }, + { + "id": 315, + "logprob": -0.0013418198, + "special": false, + "text": " of" + }, + { + "id": 5662, + "logprob": -0.16259766, + "special": false, + "text": " machine" + }, + { + "id": 6832, + "logprob": -0.0016393661, + "special": false, + "text": " learning" + }, + { + "id": 429, + "logprob": -0.4477539, + "special": false, + "text": " that" + }, + { + "id": 5711, + "logprob": -1.2802734, + "special": false, + "text": " uses" + } + ], + "top_tokens": null + }, + "generated_text": " Deep learning is a subset of machine learning that uses" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 3838, + "logprob": null, + "text": "What" + }, + { + "id": 374, + "logprob": -8.59375, + "text": " is" + }, + { + "id": 5538, + "logprob": -10.921875, + "text": " deep" + }, + { + "id": 6832, + "logprob": -0.56347656, + "text": " learning" + }, + { + "id": 30, + "logprob": -1.5, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 18183, + "logprob": -1.6669922, + "special": false, + "text": " Deep" + }, + { + "id": 6832, + "logprob": -0.08959961, + "special": false, + "text": " learning" + }, + { + "id": 374, + "logprob": -0.14685059, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.125, + "special": false, + "text": " a" + }, + { + "id": 25993, + "logprob": -0.81640625, + "special": false, + "text": " subset" + }, + { + "id": 315, + "logprob": -0.0013418198, + "special": false, + "text": " of" + }, + { + "id": 5662, + "logprob": -0.16259766, + "special": false, + "text": " machine" + }, + { + "id": 6832, + "logprob": -0.0016393661, + "special": false, + "text": " learning" + }, + { + "id": 429, + "logprob": -0.4477539, + "special": false, + "text": " that" + }, + { + "id": 5711, + "logprob": -1.2802734, + "special": false, + "text": " uses" + } + ], + "top_tokens": null + }, + "generated_text": " Deep learning is a subset of machine learning that uses" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 3838, + "logprob": null, + "text": "What" + }, + { + "id": 374, + "logprob": -8.59375, + "text": " is" + }, + { + "id": 5538, + "logprob": -10.921875, + "text": " deep" + }, + { + "id": 6832, + "logprob": -0.56347656, + "text": " learning" + }, + { + "id": 30, + "logprob": -1.5, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 18183, + "logprob": -1.6669922, + "special": false, + "text": " Deep" + }, + { + "id": 6832, + "logprob": -0.08959961, + "special": false, + "text": " learning" + }, + { + "id": 374, + "logprob": -0.14685059, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.125, + "special": false, + "text": " a" + }, + { + "id": 25993, + "logprob": -0.81640625, + "special": false, + "text": " subset" + }, + { + "id": 315, + "logprob": -0.0013418198, + "special": false, + "text": " of" + }, + { + "id": 5662, + "logprob": -0.16259766, + "special": false, + "text": " machine" + }, + { + "id": 6832, + "logprob": -0.0016393661, + "special": false, + "text": " learning" + }, + { + "id": 429, + "logprob": -0.4477539, + "special": false, + "text": " that" + }, + { + "id": 5711, + "logprob": -1.2802734, + "special": false, + "text": " uses" + } + ], + "top_tokens": null + }, + "generated_text": " Deep learning is a subset of machine learning that uses" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 3838, + "logprob": null, + "text": "What" + }, + { + "id": 374, + "logprob": -8.59375, + "text": " is" + }, + { + "id": 5538, + "logprob": -10.921875, + "text": " deep" + }, + { + "id": 6832, + "logprob": -0.56347656, + "text": " learning" + }, + { + "id": 30, + "logprob": -1.5, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 18183, + "logprob": -1.6669922, + "special": false, + "text": " Deep" + }, + { + "id": 6832, + "logprob": -0.08959961, + "special": false, + "text": " learning" + }, + { + "id": 374, + "logprob": -0.14685059, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.125, + "special": false, + "text": " a" + }, + { + "id": 25993, + "logprob": -0.81640625, + "special": false, + "text": " subset" + }, + { + "id": 315, + "logprob": -0.0013418198, + "special": false, + "text": " of" + }, + { + "id": 5662, + "logprob": -0.16259766, + "special": false, + "text": " machine" + }, + { + "id": 6832, + "logprob": -0.0016393661, + "special": false, + "text": " learning" + }, + { + "id": 429, + "logprob": -0.4477539, + "special": false, + "text": " that" + }, + { + "id": 5711, + "logprob": -1.2802734, + "special": false, + "text": " uses" + } + ], + "top_tokens": null + }, + "generated_text": " Deep learning is a subset of machine learning that uses" + } +] diff --git a/integration-tests/models/test_compressed_tensors_w8a8_int.py b/integration-tests/models/test_compressed_tensors_w8a8_int.py new file mode 100644 index 00000000..ca7829c0 --- /dev/null +++ b/integration-tests/models/test_compressed_tensors_w8a8_int.py @@ -0,0 +1,90 @@ +import pytest + + +@pytest.fixture(scope="module") +def compressed_tensors_w8a8_int_handle(launcher): + with launcher( + "neuralmagic/Llama-3.2-3B-Instruct-quantized.w8a8", + num_shard=2, + quantize="compressed-tensors", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def compressed_tensors_w8a8_int(compressed_tensors_w8a8_int_handle): + await compressed_tensors_w8a8_int_handle.health(300) + return compressed_tensors_w8a8_int_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_compressed_tensors_w8a8_int( + compressed_tensors_w8a8_int, response_snapshot +): + response = await 
compressed_tensors_w8a8_int.generate( + "What is deep learning?", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert ( + response.generated_text + == " and how does it differ from traditional machine learning?\n" + ) + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_compressed_tensors_w8a8_int_all_params( + compressed_tensors_w8a8_int, response_snapshot +): + response = await compressed_tensors_w8a8_int.generate( + "What is deep learning", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text + == "What is deep learning?\nDeep learning, also known as neural network or" + ) + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_compressed_tensors_w8a8_int_load( + compressed_tensors_w8a8_int, generate_load, response_snapshot +): + responses = await generate_load( + compressed_tensors_w8a8_int, + "What is deep learning?", + max_new_tokens=10, + n=4, + ) + + assert ( + responses[0].generated_text + == " and how does it differ from traditional machine learning?\n" + ) + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py b/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py new file mode 100644 index 00000000..7cc82a4e --- /dev/null +++ b/integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py @@ -0,0 +1,92 @@ +import pytest + + +@pytest.fixture(scope="module") +def 
compressed_tensors_w8a8_int_dynamic_weight_handle(launcher): + with launcher( + "danieldk/Qwen2.5-1.5B-Instruct-w8a8-int-dynamic-weight", + num_shard=2, + quantize="compressed-tensors", + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def compressed_tensors_w8a8_int_dynamic_weight( + compressed_tensors_w8a8_int_dynamic_weight_handle, +): + await compressed_tensors_w8a8_int_dynamic_weight_handle.health(300) + return compressed_tensors_w8a8_int_dynamic_weight_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_compressed_tensors_w8a8_int_dynamic_weight( + compressed_tensors_w8a8_int_dynamic_weight, response_snapshot +): + response = await compressed_tensors_w8a8_int_dynamic_weight.generate( + "What is deep learning?", + max_new_tokens=10, + decoder_input_details=True, + ) + + assert ( + response.generated_text + == " Deep learning is a subset of machine learning that uses" + ) + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_compressed_tensors_w8a8_int_dynamic_weight_all_params( + compressed_tensors_w8a8_int_dynamic_weight, response_snapshot +): + response = await compressed_tensors_w8a8_int_dynamic_weight.generate( + "What is deep learning", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text + == "What is deep learning?\n\nDeep Learning is an area of artificial intelligence" + ) + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_compressed_tensors_w8a8_int_dynamic_weight_load( + compressed_tensors_w8a8_int_dynamic_weight, 
generate_load, response_snapshot +): + responses = await generate_load( + compressed_tensors_w8a8_int_dynamic_weight, + "What is deep learning?", + max_new_tokens=10, + n=4, + ) + + assert ( + responses[0].generated_text + == " Deep learning is a subset of machine learning that uses" + ) + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/poetry.lock b/server/poetry.lock index 34656816..b3f75a45 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1288,12 +1288,12 @@ files = [ [[package]] name = "marlin-kernels" -version = "0.3.1" +version = "0.3.5" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:705c89ed54977099a40b37dc0c796964649024f1a8819a1832118cd7b146efe1"}, + {file = "marlin_kernels-0.3.5+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:58d4bf0aa1a9533acc05f1e5bf50f727ed0129848d1fa1feb2c5c3fa482518d4"}, ] [package.dependencies] @@ -1301,16 +1301,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.5/marlin_kernels-0.3.5+cu123torch2.4-cp310-cp310-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.3.1" +version = "0.3.5" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:e1f3d123eca643149d0a4f6b81c4405d78abb3a694a78fccc8670a25b3404406"}, + {file = "marlin_kernels-0.3.5+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:a3a3653e6908db013ca96979a5ee1f6a8bb590ee7506a129e06b87d4a8cbb87d"}, ] [package.dependencies] @@ 
-1318,16 +1318,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.5/marlin_kernels-0.3.5+cu123torch2.4-cp311-cp311-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.3.1" +version = "0.3.5" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9d68367fd5e1caf2edc90b77ad5d074b11586012265a3147ecca1f1171ae22f8"}, + {file = "marlin_kernels-0.3.5+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:967b4765a591530a4b9160ae32f3f352a89ae4c71daf43220c99976987d76723"}, ] [package.dependencies] @@ -1335,16 +1335,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.5/marlin_kernels-0.3.5+cu123torch2.4-cp312-cp312-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.3.1" +version = "0.3.5" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d962277c5f7642972e298650913dd0546b9f735b706dc88bb34955b3cac7f330"}, + {file = "marlin_kernels-0.3.5+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:fbe607d5afd1e1fca6e294c3594a0ec279d1f9ea6a2fdf7f34ccb6180d15e195"}, ] [package.dependencies] @@ -1352,7 +1352,7 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl" +url = 
"https://github.com/danieldk/marlin-kernels/releases/download/v0.3.5/marlin_kernels-0.3.5+cu123torch2.4-cp39-cp39-linux_x86_64.whl" [[package]] name = "mdurl" @@ -4066,4 +4066,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "05add88628d836faceae1a26fde4092651a6eca74555ae38ebff879a7895be7e" +content-hash = "b889115cee7f1969856f233e74721965f692e40d2a1c2fceccaf6b3bdb19680d" diff --git a/server/pyproject.toml b/server/pyproject.toml index f039ca8a..194b04da 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -48,10 +48,10 @@ attention-kernels = [ { url = "https://github.com/danieldk/attention-kernels/releases/download/v0.1.1/attention_kernels-0.1.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] marlin-kernels = [ - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.5/marlin_kernels-0.3.5+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.5/marlin_kernels-0.3.5+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, + { url = 
"https://github.com/danieldk/marlin-kernels/releases/download/v0.3.5/marlin_kernels-0.3.5+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.5/marlin_kernels-0.3.5+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] moe-kernels = [ { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, diff --git a/server/text_generation_server/layers/compressed_tensors/loader.py b/server/text_generation_server/layers/compressed_tensors/loader.py index e5ad3529..957277bf 100644 --- a/server/text_generation_server/layers/compressed_tensors/loader.py +++ b/server/text_generation_server/layers/compressed_tensors/loader.py @@ -12,6 +12,7 @@ from pydantic import ValidationError from torch import nn from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader +from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader from text_generation_server.layers.compressed_tensors.wna16_int import WNA16Loader from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import ( @@ -151,6 +152,17 @@ class CompressedTensorsLoader(WeightsLoader): ): # INT W4A16 or W8A16 (GPTQ/AWQ-like). 
return WNA16Loader(weights) + elif ( + format + in { + CompressionFormat.int_quantized.value, + CompressionFormat.naive_quantized.value, + } + and weights is not None + and weights.type == QuantizationType.INT + and weights.num_bits == 8 + ): + return W8A8IntLoader(input_args=input_activations, weight_args=weights) + else: + raise ValueError( + f"Group '{group_name}' has unsupported compressed-tensors configuration" diff --git a/server/text_generation_server/layers/compressed_tensors/w8a8_int.py b/server/text_generation_server/layers/compressed_tensors/w8a8_int.py new file mode 100644 index 00000000..fc6d81e4 --- /dev/null +++ b/server/text_generation_server/layers/compressed_tensors/w8a8_int.py @@ -0,0 +1,241 @@ +from typing import List, Optional, Union, TypeVar +from dataclasses import dataclass + +from loguru import logger +import torch +from compressed_tensors.quantization import QuantizationArgs, QuantizationType + +from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader + +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + + +class W8A8IntLoader(WeightsLoader): + """ + Loader for w8a8 integer compressed-tensors parameters. 
+ """ + + def __init__( + self, + *, + input_args: Optional[QuantizationArgs], + weight_args: QuantizationArgs, + ): + if weight_args.type != QuantizationType.INT or weight_args.num_bits != 8: + raise ValueError( + f"{type(self).__name__} only supports w8a8 int checkpoints" + ) + + if not weight_args.symmetric: + raise ValueError("Checkpoints with asymmetric weights are not supported") + + self.load_weight_scale = not weight_args.dynamic + + if input_args is not None: + self.input_symmetric = input_args.symmetric + + if not input_args.dynamic: + log_once( + logger.warning, + "Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).", + ) + else: + self.input_symmetric = True + + def __str__(self) -> str: + def scale_to_str(scale): + return "static" if scale else "dynamic" + + def symmetric_to_str(symmetric): + return "symmetric" if symmetric else "asymmetric" + + return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_str(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric)" + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight", to_dtype=False) + + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor( + f"{prefix}.weight_scale", to_dtype=False + ).reshape(-1) + + return Int8Weight( + input_symmetric=self.input_symmetric, + weight=w, + weight_scale=weight_scale, + ) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes, to_dtype=False + ) + + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + if weight_scale.numel() > 1: + weight_scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", + dim=0, + block_sizes=block_sizes, + to_dtype=False, + ) + weight_scale = 
weight_scale.reshape(-1) + + return Int8Weight( + input_symmetric=self.input_symmetric, + weight=w, + weight_scale=weight_scale, + ) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [ + weights.get_sharded(f"{p}.weight", dim=0, to_dtype=False) for p in prefixes + ] + shapes = [x.shape for x in w] + + w = torch.cat(w, dim=dim) + + weight_scale = None + if self.load_weight_scale: + weight_scale = [ + _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape) + for p, shape in zip(prefixes, shapes) + ] + weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1) + + return Int8Weight( + input_symmetric=self.input_symmetric, + weight=w, + weight_scale=weight_scale, + ) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1, to_dtype=False) + + weight_scale = None + if self.load_weight_scale: + weight_scale = weights.get_tensor( + f"{prefix}.weight_scale", to_dtype=False + ).reshape(-1) + + return Int8Weight( + input_symmetric=self.input_symmetric, + weight=w, + weight_scale=weight_scale, + ) + + +OtherT = TypeVar("OtherT") + + +def _get_tensor_or_else( + weights: Weights, prefix: str, other: OtherT +) -> Union[torch.Tensor, OtherT]: + # Even if a checkpoint uses e.g. 
zero-points, they can be elided: + # https://github.com/neuralmagic/compressed-tensors/blob/db6ccb25b265e8370813ecab5e95714a6728b5a6/src/compressed_tensors/compressors/quantized_compressors/base.py#L105 + if weights.has_tensor(prefix): + return weights.get_tensor(prefix, to_dtype=False) + else: + return other + + +@dataclass +class Int8Weight(Weight): + input_symmetric: bool + weight: torch.Tensor + weight_scale: Optional[torch.Tensor] + + def get_linear(self, bias: torch.Tensor): + if self.weight_scale is None: + assert marlin_kernels is not None + qweight, weight_scale, _ = marlin_kernels.scaled_int8_quant(self.weight) + return W8A8IntLinear( + bias=bias, + input_symmetric=self.input_symmetric, + weight=qweight, + weight_scale=weight_scale, + ) + else: + return W8A8IntLinear( + bias=bias, + input_symmetric=self.input_symmetric, + weight=self.weight, + weight_scale=self.weight_scale, + ) + + +class W8A8IntLinear(torch.nn.Module): + def __init__( + self, + *, + bias: Optional[torch.Tensor], + input_symmetric: bool, + weight: torch.Tensor, + weight_scale: torch.Tensor, + ): + super().__init__() + + weight_scale = weight_scale.to(torch.float32) + + self.bias = bias + self.input_symmetric = input_symmetric + # cutlass kernels require transposed weights. 
+ self.weight = weight.t() + self.weight_scale = weight_scale + + if input_symmetric: + self.zero_point_adj = None + else: + # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md#scaledepilogueazp + self.zero_point_adj = self.weight.sum( + dim=0, keepdim=True, dtype=torch.int32 + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + assert marlin_kernels is not None + + qinput, input_scale, input_zero_point = marlin_kernels.scaled_int8_quant( + input=input, + scale=None, + azp=None, + symmetric=self.input_symmetric, + ) + + if self.input_symmetric: + return marlin_kernels.cutlass_scaled_mm( + a=qinput, + b=self.weight, + scale_a=input_scale, + scale_b=self.weight_scale, + out_dtype=input.dtype, + bias=self.bias, + ) + else: + assert ( + self.zero_point_adj is not None + and input_scale is not None + and (self.input_symmetric or input_zero_point is not None) + ) + + return marlin_kernels.cutlass_scaled_mm_azp( + a=qinput, + b=self.weight, + scale_a=input_scale, + scale_b=self.weight_scale, + out_dtype=input.dtype, + azp_adj=self.zero_point_adj, + azp=input_zero_point, + bias=self.bias, + ) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index aae64acf..c03dd2b0 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -220,6 +220,7 @@ class Weights: tensor.dtype not in [ torch.float8_e4m3fn, + torch.int8, torch.int16, torch.int32, torch.int64, @@ -255,7 +256,8 @@ class Weights: # u4 which are disguised as int32. exl2 uses int16. # FP8 uses torch.float8_e4m3fn. 
if ( - tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) + tensor.dtype + not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32) and to_dtype ): tensor = tensor.to(dtype=self.dtype) @@ -331,6 +333,7 @@ class Weights: tensor.dtype not in [ torch.float8_e4m3fn, + torch.int8, torch.int16, torch.int32, torch.int64,