From a4e5801684ea2b34bc14dbbacffc08ca2f7f71af Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Fri, 9 Feb 2024 10:45:16 +0100 Subject: [PATCH] ROCm AWQ support (#1514) # What does this PR do? This PR adds the possibility to run AWQ models with Exllama/GPTQ kernels, specifically for ROCm devices that support Exllama kernels but not AWQ's GEMM. This is done by : - un-packing, reordering and re-packing AWQ weights when `--quantize gptq` but the model's `quant_method=awq`. - avoiding overflows when adding 1 to zeros in exllama and triton. Ref: https://github.com/casper-hansen/AutoAWQ/pull/313 ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --------- Co-authored-by: Nicolas Patry --- .../test_flash_llama_gptq.json | 69 ++-- .../test_flash_llama_gptq_all_params.json | 25 +- .../test_flash_llama_gptq_load.json | 276 ++++++++------- .../test_flash_starcoder_gptq.json | 247 ++++++------- ...t_flash_starcoder_gptq_default_params.json | 63 ++-- .../test_flash_starcoder_gptq_load.json | 332 +++++++++--------- .../exllama_kernels/cuda_func/q4_matmul.cu | 8 +- .../exllama_kernels/cuda_func/q4_matrix.cu | 2 +- .../cuda/q_gemm_kernel_gptq.cuh | 16 +- .../exllamav2_kernels/cuda/q_matrix.cu | 16 +- .../flash_santacoder_modeling.py | 14 +- .../utils/awq/conversion_utils.py | 97 +++++ .../utils/gptq/quant_linear.py | 2 +- server/text_generation_server/utils/layers.py | 7 + .../text_generation_server/utils/weights.py | 114 ++++-- 15 files changed, 737 insertions(+), 551 deletions(-) create mode 100644 server/text_generation_server/utils/awq/conversion_utils.py diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json index e4ffb83b..7797cc6c 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json @@ -11,78 +11,79 @@ }, { "id": 4321, - "logprob": -9.59375, + "logprob": -9.7890625, "text": "Test" }, { "id": 2009, - "logprob": -9.6640625, + "logprob": -9.625, "text": "request" } ], "seed": null, "tokens": [ - { - "id": 29918, - "logprob": -2.3867188, - "special": false, - "text": "_" - }, - { - "id": 5338, - "logprob": -2.8183594, - "special": false, - "text": "uri" - }, { "id": 13, - "logprob": -1.6367188, + "logprob": -2.3359375, "special": false, "text": "\n" }, { "id": 3057, - "logprob": -1.0527344, + "logprob": -1.8779297, "special": false, "text": "Test" }, { "id": 2009, - "logprob": -0.6542969, + "logprob": -1.2744141, "special": false, "text": " request" }, - { - "id": 29918, - "logprob": -0.056121826, - "special": false, - "text": "_" - }, - { - "id": 5338, - "logprob": -0.01600647, - "special": false, - "text": "uri" - }, { "id": 13, - "logprob": -0.87939453, + "logprob": -1.6933594, "special": false, "text": "\n" }, { "id": 3057, - "logprob": -0.7529297, + "logprob": -1.4648438, "special": false, "text": "Test" }, { "id": 2009, - "logprob": -0.2980957, + "logprob": -0.15600586, "special": false, "text": " request" + }, + { + "id": 13, + "logprob": -0.8027344, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.23022461, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.0069885254, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.02218628, + "special": false, + "text": "\n" } - ] + ], + "top_tokens": null }, - "generated_text": "_uri\nTest request_uri\nTest request" + "generated_text": "\nTest request\nTest request\nTest request\n" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json index 02713a00..fa2fd4a2 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json @@ -11,12 +11,12 @@ }, { "id": 4321, - "logprob": -9.6015625, + "logprob": -9.84375, "text": "Test" }, { "id": 2009, - "logprob": -9.6640625, + "logprob": -9.6015625, "text": "request" } ], @@ -24,13 +24,13 @@ "tokens": [ { "id": 29899, - "logprob": -1.1640625, + "logprob": -1.5625, "special": false, "text": "-" }, { "id": 1454, - "logprob": -0.07543945, + "logprob": -0.20410156, "special": false, "text": "for" }, @@ -54,19 +54,19 @@ }, { "id": 396, - "logprob": -0.2956543, + "logprob": -0.27685547, "special": false, "text": " #" }, { "id": 29906, - "logprob": -0.52734375, + "logprob": -0.4970703, "special": false, "text": "2" }, { "id": 29900, - "logprob": -0.6899414, + "logprob": -0.80615234, "special": false, "text": "0" }, @@ -77,12 +77,13 @@ "text": "1" }, { - "id": 29946, - "logprob": -1.5068359, + "id": 29955, + "logprob": -1.0751953, "special": false, - "text": "4" + "text": "7" } - ] + ], + "top_tokens": null }, - "generated_text": "Test request-for-comment: #2014" + "generated_text": "Test request-for-comment: #2017" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json index 88bfa4f9..594b7351 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json @@ -12,80 +12,81 @@ }, { "id": 4321, - "logprob": -9.6015625, + "logprob": -9.828125, "text": "Test" }, { "id": 2009, - "logprob": -9.671875, + "logprob": -9.609375, "text": "request" } ], "seed": null, "tokens": [ - { - "id": 29918, - "logprob": -2.3828125, - "special": false, - "text": "_" - }, - { - "id": 5338, - "logprob": -2.8105469, - "special": false, - "text": "uri" - }, { "id": 13, - "logprob": -1.6396484, + "logprob": -2.3300781, "special": false, "text": "\n" }, { "id": 3057, - "logprob": -1.0546875, + "logprob": -1.8740234, "special": false, "text": "Test" }, { "id": 2009, - "logprob": -0.6513672, + "logprob": -1.2646484, "special": false, "text": " request" }, - { - "id": 29918, - "logprob": -0.056365967, - "special": false, - "text": "_" - }, - { - "id": 5338, - "logprob": -0.016082764, - "special": false, - "text": "uri" - }, { "id": 13, - "logprob": -0.87841797, + "logprob": -1.7158203, "special": false, "text": "\n" }, { "id": 3057, - "logprob": -0.7548828, + "logprob": -1.4667969, "special": false, "text": "Test" }, { "id": 2009, - "logprob": -0.29711914, + "logprob": -0.15344238, "special": false, "text": " request" + }, + { + "id": 13, + "logprob": -0.81591797, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22973633, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007045746, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021957397, + "special": false, + "text": "\n" } - ] + ], + "top_tokens": null }, - "generated_text": "_uri\nTest request_uri\nTest request" + "generated_text": "\nTest request\nTest request\nTest request\n" }, { "details": { @@ -100,80 +101,81 @@ }, { "id": 4321, - "logprob": -9.6015625, + "logprob": -9.84375, "text": "Test" }, { "id": 2009, - "logprob": -9.6640625, + "logprob": -9.59375, "text": "request" } ], "seed": null, "tokens": [ - { - "id": 29918, - "logprob": -2.3828125, - "special": false, - "text": "_" - }, - { - "id": 5338, - "logprob": -2.828125, - "special": false, - "text": "uri" - }, { "id": 13, - "logprob": -1.6386719, + "logprob": -2.3378906, "special": false, "text": "\n" }, { "id": 3057, - "logprob": -1.0527344, + "logprob": -1.8779297, "special": false, "text": "Test" }, { "id": 2009, - "logprob": -0.6542969, + "logprob": -1.2636719, "special": false, "text": " request" }, - { - "id": 29918, - "logprob": -0.055877686, - "special": false, - "text": "_" - }, - { - "id": 5338, - "logprob": -0.016021729, - "special": false, - "text": "uri" - }, { "id": 13, - "logprob": -0.8769531, + "logprob": -1.6992188, "special": false, "text": "\n" }, { "id": 3057, - "logprob": -0.7583008, + "logprob": -1.4589844, "special": false, "text": "Test" }, { "id": 2009, - "logprob": -0.29833984, + "logprob": -0.15344238, "special": false, "text": " request" + }, + { + "id": 13, + "logprob": -0.79052734, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22937012, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.007041931, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.022140503, + "special": false, + "text": "\n" } - ] + ], + "top_tokens": null }, - "generated_text": "_uri\nTest request_uri\nTest request" + "generated_text": "\nTest request\nTest request\nTest request\n" }, { "details": { @@ -188,80 +190,81 @@ }, { "id": 4321, - "logprob": -9.6015625, + "logprob": -9.84375, "text": "Test" }, { "id": 2009, - "logprob": -9.671875, + "logprob": -9.609375, "text": "request" } ], "seed": null, "tokens": [ - { - "id": 29918, - "logprob": -2.3847656, - "special": false, - "text": "_" - }, - { - "id": 5338, - "logprob": -2.8144531, - "special": false, - "text": "uri" - }, { "id": 13, - "logprob": -1.6396484, + "logprob": -2.3261719, "special": false, "text": "\n" }, { "id": 3057, - "logprob": -1.0527344, + "logprob": -1.8730469, "special": false, "text": "Test" }, { "id": 2009, - "logprob": -0.65478516, + "logprob": -1.2587891, "special": false, "text": " request" }, - { - "id": 29918, - "logprob": -0.056243896, - "special": false, - "text": "_" - }, - { - "id": 5338, - "logprob": -0.016143799, - "special": false, - "text": "uri" - }, { "id": 13, - "logprob": -0.8808594, + "logprob": -1.6894531, "special": false, "text": "\n" }, { "id": 3057, - "logprob": -0.75341797, + "logprob": -1.46875, "special": false, "text": "Test" }, { "id": 2009, - "logprob": -0.2956543, + "logprob": -0.1541748, "special": false, "text": " request" + }, + { + "id": 13, + "logprob": -0.80322266, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22912598, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.0070495605, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021606445, + "special": false, + "text": "\n" } - ] + ], + "top_tokens": null }, - "generated_text": "_uri\nTest request_uri\nTest request" + "generated_text": "\nTest request\nTest request\nTest request\n" }, { "details": { @@ -276,79 +279,80 @@ }, { "id": 4321, - "logprob": -9.6015625, + "logprob": -9.84375, "text": "Test" }, { "id": 2009, - "logprob": -9.6640625, + "logprob": -9.6015625, "text": "request" } ], "seed": null, "tokens": [ - { - "id": 29918, - "logprob": -2.3769531, - "special": false, - "text": "_" - }, - { - "id": 5338, - "logprob": -2.8183594, - "special": false, - "text": "uri" - }, { "id": 13, - "logprob": -1.6396484, + "logprob": -2.3320312, "special": false, "text": "\n" }, { "id": 3057, - "logprob": -1.0546875, + "logprob": -1.875, "special": false, "text": "Test" }, { "id": 2009, - "logprob": -0.65478516, + "logprob": -1.2646484, "special": false, "text": " request" }, - { - "id": 29918, - "logprob": -0.05557251, - "special": false, - "text": "_" - }, - { - "id": 5338, - "logprob": -0.01612854, - "special": false, - "text": "uri" - }, { "id": 13, - "logprob": -0.8730469, + "logprob": -1.6884766, "special": false, "text": "\n" }, { "id": 3057, - "logprob": -0.7519531, + "logprob": -1.4589844, "special": false, "text": "Test" }, { "id": 2009, - "logprob": -0.29785156, + "logprob": -0.15185547, "special": false, "text": " request" + }, + { + "id": 13, + "logprob": -0.79833984, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.22827148, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.006996155, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.021560669, + "special": false, + "text": "\n" } - ] + ], + "top_tokens": null }, - "generated_text": "_uri\nTest request_uri\nTest request" + "generated_text": "\nTest request\nTest request\nTest request\n" } ] diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json index 53055e42..5e537bb7 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json @@ -1,193 +1,194 @@ { - "generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L", "details": { "best_of_sequences": null, "finish_reason": "length", "generated_tokens": 20, - "seed": null, "prefill": [ { "id": 589, - "text": "def", - "logprob": null + "logprob": null, + "text": "def" }, { "id": 3226, - "text": " ge", - "logprob": -9.0234375 + "logprob": -8.5859375, + "text": " ge" }, { "id": 21017, - "text": "ometric", - "logprob": -9.0859375 + "logprob": -7.5859375, + "text": "ometric" }, { "id": 81, - "text": "_", - "logprob": -0.25878906 + "logprob": -0.2668457, + "text": "_" }, { "id": 6009, - "text": "mean", - "logprob": -2.2109375 + "logprob": -1.6416016, + "text": "mean" }, { "id": 26, - "text": "(", - "logprob": -0.30371094 + "logprob": -0.22705078, + "text": "(" }, { "id": 62, - "text": "L", - "logprob": -5.6054688 + "logprob": -5.2304688, + "text": "L" }, { "id": 44, - "text": ":", - "logprob": -3.0722656 + "logprob": -3.0976562, + "text": ":" }, { "id": 1682, - "text": " List", - "logprob": -0.6879883 + "logprob": -1.1044922, + "text": " List" }, { "id": 77, - "text": "[", - "logprob": -0.38500977 + "logprob": -0.14294434, + "text": "[" }, { "id": 1808, - "text": "float", - "logprob": -0.984375 + "logprob": -0.32299805, + "text": "float" }, { "id": 10794, - "text": "]):", - "logprob": -2.5351562 + "logprob": -2.8164062, + "text": "]):" } ], + "seed": null, "tokens": [ { "id": 284, - "text": "\n ", - "logprob": -1.1738281, - "special": false + "logprob": -0.1282959, + "special": false, + "text": "\n " }, { - "id": 442, - "text": " return", - "logprob": -0.95947266, - "special": false + "id": 1524, + "logprob": -0.97998047, + "special": false, + "text": " \"\"\"" }, { - "id": 3632, - "text": " sum", - "logprob": -1.4199219, - "special": false + "id": 284, + "logprob": -0.7006836, + "special": false, + "text": "\n " }, { - "id": 26, - "text": "(", - "logprob": -0.085876465, - "special": false + "id": 14883, + "logprob": -2.1933594, + "special": false, + "text": " Calculate" }, { - "id": 62, - "text": "L", - "logprob": -0.09875488, - "special": false - }, - { - "id": 27, - "text": ")", - "logprob": -0.30517578, - "special": false - }, - { - "id": 517, - "text": " /", - "logprob": -0.42089844, - "special": false - }, - { - "id": 2069, - "text": " len", - "logprob": -0.042053223, - "special": false - }, - { - "id": 26, - "text": "(", - "logprob": -0.0011806488, - "special": false - }, - { - "id": 62, - "text": "L", - "logprob": -0.0005259514, - "special": false - }, - { - "id": 27, - "text": ")", - "logprob": -0.0017633438, - "special": false - }, - { - "id": 478, - "text": "\n\n", - "logprob": -0.69189453, - "special": false - }, - { - "id": 203, - "text": "\n", - "logprob": -0.041870117, - "special": false - }, - { - "id": 589, - "text": "def", - "logprob": -0.27856445, - "special": false + "id": 322, + "logprob": -0.2697754, + "special": false, + "text": " the" }, { "id": 3226, - "text": " ge", - "logprob": -1.7255859, - "special": false + "logprob": -0.0836792, + "special": false, + "text": " ge" }, { "id": 21017, - "text": "ometric", - "logprob": -0.011291504, - "special": false + "logprob": -0.018737793, + "special": false, + "text": "ometric" }, { - "id": 81, - "text": "_", - "logprob": -0.008430481, - "special": false + "id": 5651, + "logprob": -0.028640747, + "special": false, + "text": " mean" }, { - "id": 6009, - "text": "mean", - "logprob": -0.025787354, - "special": false + "id": 432, + "logprob": -0.29467773, + "special": false, + "text": " of" }, { - "id": 26, - "text": "(", - "logprob": -0.073913574, - "special": false + "id": 312, + "logprob": -0.31518555, + "special": false, + "text": " a" }, { - "id": 62, - "text": "L", - "logprob": -0.09967041, - "special": false + "id": 1149, + "logprob": -0.20605469, + "special": false, + "text": " list" + }, + { + "id": 432, + "logprob": -0.23254395, + "special": false, + "text": " of" + }, + { + "id": 7515, + "logprob": -0.4489746, + "special": false, + "text": " numbers" + }, + { + "id": 32, + "logprob": -0.6044922, + "special": false, + "text": "." + }, + { + "id": 446, + "logprob": -0.63964844, + "special": false, + "text": "\n\n " + }, + { + "id": 499, + "logprob": -1.1953125, + "special": false, + "text": " :" + }, + { + "id": 753, + "logprob": -0.03515625, + "special": false, + "text": "param" + }, + { + "id": 498, + "logprob": -0.06311035, + "special": false, + "text": " L" + }, + { + "id": 44, + "logprob": -0.003414154, + "special": false, + "text": ":" + }, + { + "id": 1682, + "logprob": -1.3310547, + "special": false, + "text": " List" } - ] - } + ], + "top_tokens": null + }, + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json index 1ace3814..bf0f5146 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json @@ -11,57 +11,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5898438, "text": "ometric" }, { "id": 81, - "logprob": -0.25830078, + "logprob": -0.26586914, "text": "_" }, { "id": 6009, - "logprob": -2.1875, + "logprob": -1.6347656, "text": "mean" }, { "id": 26, - "logprob": -0.30004883, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.6171875, + "logprob": -5.2382812, "text": "L" }, { "id": 44, - "logprob": -3.078125, + "logprob": -3.0996094, "text": ":" }, { "id": 1682, - "logprob": -0.68066406, + "logprob": -1.1025391, "text": " List" }, { "id": 77, - "logprob": -0.38745117, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.9453125, + "logprob": -0.32226562, "text": "float" }, { "id": 10794, - "logprob": -2.5371094, + "logprob": -2.8164062, "text": "]):" } ], @@ -69,19 +69,19 @@ "tokens": [ { "id": 284, - "logprob": -0.051635742, + "logprob": 0.0, "special": false, "text": "\n " }, { "id": 442, - "logprob": 0.0, + "logprob": -1.3134766, "special": false, "text": " return" }, { "id": 11665, - "logprob": -1.2236328, + "logprob": -0.10021973, "special": false, "text": " reduce" }, @@ -129,7 +129,7 @@ }, { "id": 319, - "logprob": 0.0, + "logprob": -0.42871094, "special": false, "text": " *" }, @@ -158,36 +158,37 @@ "text": ")" }, { - "id": 203, - "logprob": -0.12695312, - "special": false, - "text": "\n" - }, - { - "id": 203, + "id": 1115, "logprob": 0.0, "special": false, - "text": "\n" + "text": " **" }, { - "id": 589, + "id": 308, "logprob": 0.0, "special": false, - "text": "def" + "text": " (" }, { - "id": 3226, + "id": 35, "logprob": 0.0, "special": false, - "text": " ge" + "text": "1" }, { - "id": 21017, + "id": 32, + "logprob": -0.31323242, + "special": false, + "text": "." + }, + { + "id": 34, "logprob": 0.0, "special": false, - "text": "ometric" + "text": "0" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric" + "generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json index 5381ce5a..46a21ed8 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json @@ -12,57 +12,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5820312, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26708984, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6386719, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22717285, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.1015625, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + "logprob": -1.1083984, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.32592773, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8164062, "text": "]):" } ], @@ -70,67 +70,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.12817383, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.91796875, + "id": 1524, + "logprob": -0.9863281, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.3291016, + "id": 284, + "logprob": -0.7011719, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.08062744, + "id": 14883, + "logprob": -2.2050781, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.097717285, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.29003906, + "id": 3226, + "logprob": -0.08465576, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.34958984, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.03829956, + "id": 5651, + "logprob": -0.028625488, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011987686, + "id": 432, + "logprob": -0.29418945, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.00050878525, + "id": 312, + "logprob": -0.3161621, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -145,57 +146,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.59375, "text": "ometric" }, { "id": 81, - "logprob": -0.25878906, + "logprob": -0.26953125, "text": "_" }, { "id": 6009, - "logprob": -2.2109375, + "logprob": -1.640625, "text": "mean" }, { "id": 26, - "logprob": -0.30371094, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.6054688, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0722656, + "logprob": -3.1132812, "text": ":" }, { "id": 1682, - "logprob": -0.6879883, + "logprob": -1.1123047, "text": " List" }, { "id": 77, - "logprob": -0.38500977, + "logprob": -0.14294434, "text": "[" }, { "id": 1808, - "logprob": -0.984375, + "logprob": -0.32299805, "text": "float" }, { "id": 10794, - "logprob": -2.5351562, + "logprob": -2.8164062, "text": "]):" } ], @@ -203,67 +204,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1738281, + "logprob": -0.12854004, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.9584961, + "id": 1524, + "logprob": -0.9897461, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.4169922, + "id": 284, + "logprob": -0.69970703, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.085876465, + "id": 14883, + "logprob": -2.2050781, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.0982666, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.3022461, + "id": 3226, + "logprob": -0.08496094, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.40504883, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.041656494, + "id": 5651, + "logprob": -0.029037476, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011844635, + "id": 432, + "logprob": -0.2939453, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.0005264282, + "id": 312, + "logprob": -0.31591797, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -278,57 +280,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5859375, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26586914, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6347656, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22766113, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.2265625, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.0976562, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + "logprob": -1.1025391, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.1427002, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.32592773, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8164062, "text": "]):" } ], @@ -336,67 +338,68 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.13012695, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.9165039, + "id": 1524, + "logprob": -0.98046875, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.328125, + "id": 284, + "logprob": -0.69921875, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.07946777, + "id": 14883, + "logprob": -2.1992188, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.09820557, + "id": 322, + "logprob": -0.2668457, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.28930664, + "id": 3226, + "logprob": -0.083496094, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.34592773, + "id": 21017, + "logprob": -0.01902771, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.038330078, + "id": 5651, + "logprob": -0.029006958, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011940002, + "id": 432, + "logprob": -0.29248047, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.00050878525, + "id": 312, + "logprob": -0.3161621, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" }, { "details": { @@ -411,57 +414,57 @@ }, { "id": 3226, - "logprob": -9.0234375, + "logprob": -8.5859375, "text": " ge" }, { "id": 21017, - "logprob": -9.0859375, + "logprob": -7.5859375, "text": "ometric" }, { "id": 81, - "logprob": -0.25927734, + "logprob": -0.26904297, "text": "_" }, { "id": 6009, - "logprob": -2.25, + "logprob": -1.6386719, "text": "mean" }, { "id": 26, - "logprob": -0.30126953, + "logprob": -0.22705078, "text": "(" }, { "id": 62, - "logprob": -5.7539062, + "logprob": -5.234375, "text": "L" }, { "id": 44, - "logprob": -3.0878906, + "logprob": -3.1132812, "text": ":" }, { "id": 1682, - "logprob": -0.6845703, + "logprob": -1.1074219, "text": " List" }, { "id": 77, - "logprob": -0.3918457, + "logprob": -0.14477539, "text": "[" }, { "id": 1808, - "logprob": -0.8798828, + "logprob": -0.3256836, "text": "float" }, { "id": 10794, - "logprob": -2.4980469, + "logprob": -2.8027344, "text": "]):" } ], @@ -469,66 +472,67 @@ "tokens": [ { "id": 284, - "logprob": -1.1533203, + "logprob": -0.12915039, "special": false, "text": "\n " }, { - "id": 442, - "logprob": -0.91259766, + "id": 1524, + "logprob": -0.98535156, "special": false, - "text": " return" + "text": " \"\"\"" }, { - "id": 3632, - "logprob": -1.3251953, + "id": 284, + "logprob": -0.69921875, "special": false, - "text": " sum" + "text": "\n " }, { - "id": 26, - "logprob": -0.08062744, + "id": 14883, + "logprob": -2.2011719, "special": false, - "text": "(" + "text": " Calculate" }, { - "id": 62, - "logprob": -0.09906006, + "id": 322, + "logprob": -0.26708984, "special": false, - "text": "L" + "text": " the" }, { - "id": 27, - "logprob": -0.28979492, + "id": 3226, + "logprob": -0.08502197, "special": false, - "text": ")" + "text": " ge" }, { - "id": 517, - "logprob": -0.35958984, + "id": 21017, + "logprob": -0.019012451, "special": false, - "text": " /" + "text": "ometric" }, { - "id": 2069, - "logprob": -0.038604736, + "id": 5651, + "logprob": -0.028625488, "special": false, - "text": " len" + "text": " mean" }, { - "id": 26, - "logprob": -0.0011901855, + "id": 432, + "logprob": -0.29589844, "special": false, - "text": "(" + "text": " of" }, { - "id": 62, - "logprob": -0.0005078316, + "id": 312, + "logprob": -0.31591797, "special": false, - "text": "L" + "text": " a" } - ] + ], + "top_tokens": null }, - "generated_text": "\n return sum(L) / len(L" + "generated_text": "\n \"\"\"\n Calculate the geometric mean of a" } ] diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu index 61380f42..09126efe 100644 --- a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu @@ -85,7 +85,7 @@ __global__ void q4_matmul_kernel if constexpr (use_half2) { half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); @@ -93,7 +93,7 @@ __global__ void q4_matmul_kernel else { half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); @@ -110,7 +110,7 @@ __global__ void q4_matmul_kernel { int group = k / groupsize; half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); @@ -119,7 +119,7 @@ __global__ void q4_matmul_kernel { int group = k / groupsize; half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F; if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu index f3d1564f..2867a8d0 100644 --- a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu @@ -189,7 +189,7 @@ __global__ void reconstruct_kernel int group = row / groupsize; half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; + uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F; uint32_t w_read = w_.item_uint32_t(row, column); half* out_ptr = out_.item_ptr(row, column); diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh index 74b0db2b..f816fd9d 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh @@ -152,10 +152,10 @@ __global__ void gemm_half_q_half_gptq_kernel half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); // __syncthreads(); @@ -174,10 +174,10 @@ __global__ void gemm_half_q_half_gptq_kernel nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); } #pragma unroll diff --git a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu index ae08cc1f..7a0038b4 100644 --- a/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu +++ b/server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu @@ -237,10 +237,10 @@ __global__ void reconstruct_gptq_kernel half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); __syncthreads(); @@ -255,10 +255,10 @@ __global__ void reconstruct_gptq_kernel nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]); } for (int p = 0; p < 4; p++) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 22d03adf..81041046 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -69,9 +69,17 @@ def _load_multi_mqa_gptq( qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = qzeros.to(device=weights.device) - g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") - g_idx = g_idx.to(device=weights.device) - bits, groupsize, _ = weights._get_gptq_params() + bits, groupsize, _, quant_method, = weights._get_gptq_params() + if quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") + g_idx = g_idx.to(device=weights.device) + elif quant_method == "awq": + g_idx = None + from text_generation_server.utils.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) from text_generation_server.utils.layers import HAS_EXLLAMA diff --git a/server/text_generation_server/utils/awq/conversion_utils.py b/server/text_generation_server/utils/awq/conversion_utils.py new file mode 100644 index 00000000..b19eafbb --- /dev/null +++ b/server/text_generation_server/utils/awq/conversion_utils.py @@ -0,0 +1,97 @@ +import torch +from typing import List + + +AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] +REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + + +def pack(imatrix: torch.Tensor, direction: str = "column"): + """ + Packs a 4-bit integer matrix into a packed 32-bit integer matrix. + Args: + imatrix (torch.Tensor): matrix of integers + direction (str): direction of packing, either "column" or "row" + Returns: + qmatrix (torch.Tensor): packed matrix of integers + """ + shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device) + + imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow + + if direction == "column": + imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4)) + qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1) + + elif direction == "row": + imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1) + qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1) + + qmatrix = qmatrix.to(torch.int32) + + return qmatrix + + +def unpack(qmatrix: torch.Tensor, direction: str = "column"): + """ + Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix. + Args: + qmatrix (torch.Tensor): matrix of packed integers + direction (str): direction of unpacking, either "column" or "row" + Returns: + imatrix (torch.Tensor): matrix of integers + """ + shifts = torch.arange(0, 32, 4, device=qmatrix.device) + + if direction == "column": + imatrix = torch.bitwise_right_shift( + qmatrix[:, :, None], shifts[None, None, :] + ).view(qmatrix.shape[0], -1) + + elif direction == "row": + imatrix = torch.bitwise_right_shift( + qmatrix[:, None, :], shifts[None, :, None] + ).view(-1, qmatrix.shape[-1]) + + imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow + + return imatrix + + +def apply_order( + imatrix: torch.Tensor, + direction: str = "column", + order: List[int] = AWQ_PACK_ORDER, +): + """ + Applies the order to a 4-bit integer matrix. + Args: + imatrix (torch.Tensor): matrix of integers + direction (str): direction of applying order, either "column" or "row" + order (List[int]): order to apply, default is AWQ_PACK_ORDER + Returns: + imatrix (torch.Tensor): matrix of integers + """ + if direction == "column": + imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape) + elif direction == "row": + imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape) + + return imatrix + + +def fast_awq_to_gptq(qweight, qzeros): + # awq uses column packing for both weights and zeros + izeros = unpack(qzeros, direction="column") + iweights = unpack(qweight, direction="column") + + # Reverse the order of the iweight and izeros tensors + izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER) + iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER) + # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros) + izeros = izeros - 1 + # exllama uses row packing for weights and column packing for zeros + qzeros = pack(izeros, direction="column") + qweight = pack(iweights, direction="row") + + return qweight, qzeros diff --git a/server/text_generation_server/utils/gptq/quant_linear.py b/server/text_generation_server/utils/gptq/quant_linear.py index bfc91c00..8ad0dd80 100644 --- a/server/text_generation_server/utils/gptq/quant_linear.py +++ b/server/text_generation_server/utils/gptq/quant_linear.py @@ -182,7 +182,7 @@ try: ) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq - zeros = zeros + 1 + zeros = (zeros + 1) & maxq # eventually avoid overflow a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 010d6143..01e32588 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -349,6 +349,13 @@ def get_linear(weight, bias, quantize): raise NotImplementedError( f"The passed weight is not `awq` compatible, loader needs to be updated." ) + if IS_ROCM_SYSTEM: + raise NotImplementedError( + "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " + "to use Exllama/GPTQ kernels for AWQ inference." + ) + if not HAS_AWQ: + raise NotImplementedError("You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly") linear = WQLinear( w_bit=bits, group_size=groupsize, diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 186733f3..8f7e1f10 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -46,7 +46,6 @@ class Weights: return self._handles[filename] def get_filename(self, tensor_name: str) -> (str, str): - names = [tensor_name] if self.prefix is not None: prefixed = f"{self.prefix}.{tensor_name}" @@ -154,15 +153,30 @@ class Weights: f"Cannot load `{quantize}` weight, make sure the model is already quantized." ) + bits, groupsize, _, quant_method = self._get_gptq_params() + qzeros = self._get_qweight(f"{prefix}.qzeros") scales = self._get_qweight(f"{prefix}.scales") scales = scales.to(dtype=self.dtype) - if quantize == "gptq": + + if quantize == "gptq" and quant_method == "gptq": g_idx = self.get_tensor(f"{prefix}.g_idx") + elif quantize == "gptq" and quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.utils.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + g_idx = ( + torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device) + // groupsize + ).to(dtype=torch.int32) else: g_idx = None - bits, groupsize, _ = self._get_gptq_params() weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: slice_ = self._get_slice(f"{prefix}.weight") @@ -204,20 +218,40 @@ class Weights: [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1 ) - if quantize == "gptq": - w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] - for w2 in w[1:]: - torch.testing.assert_close(w2, w[0]) - g_idx = w[0] - else: - g_idx = None + bits, groupsize, desc_act, quant_method = self._get_gptq_params() - bits, groupsize, desc_act = self._get_gptq_params() from text_generation_server.utils.layers import HAS_EXLLAMA use_exllama = ( bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act ) + + if quantize == "gptq" and quant_method == "gptq": + w = [self.get_tensor(f"{p}.g_idx") for p in prefixes] + for w2 in w[1:]: + torch.testing.assert_close(w2, w[0]) + g_idx = w[0] + elif quantize == "gptq" and quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.utils.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // bits), device=qweight.device + ) + // groupsize + ).to(dtype=torch.int32) + else: + g_idx = None + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] @@ -243,7 +277,7 @@ class Weights: def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq": use_exllama = True - bits, groupsize, desc_act = self._get_gptq_params() + bits, groupsize, desc_act, quant_method = self._get_gptq_params() if bits != 4: use_exllama = False @@ -252,8 +286,19 @@ class Weights: log_once(logger.warning, "Disabling exllama because desc_act=True") use_exllama = False + try: + qweight = self.get_sharded(f"{prefix}.qweight", dim=0) + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if quant_method == "gptq": + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + elif quant_method == "awq": + g_idx = None + if self.process_group.size() > 1: - g_idx = self.get_tensor(f"{prefix}.g_idx") if g_idx is not None: if ( not torch.equal( @@ -269,13 +314,6 @@ class Weights: # it would require to reorder input activations that are split unto several GPUs use_exllama = False - try: - qweight = self.get_sharded(f"{prefix}.qweight", dim=0) - except RuntimeError: - raise RuntimeError( - "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) - from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA if use_exllama: @@ -289,8 +327,6 @@ class Weights: else: log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) - if use_exllama and groupsize != -1: qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) scales = self.get_sharded(f"{prefix}.scales", dim=0) @@ -298,12 +334,31 @@ class Weights: qzeros = self.get_tensor(f"{prefix}.qzeros") scales = self.get_tensor(f"{prefix}.scales") - if use_exllama: + if use_exllama and g_idx is not None: g_idx = g_idx - g_idx[0] + if quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.utils.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // bits), device=qweight.device + ) + // groupsize + ).to(dtype=torch.int32) + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) elif quantize == "awq": - bits, groupsize, _ = self._get_gptq_params() + bits, groupsize, _, _ = self._get_gptq_params() try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) @@ -322,20 +377,22 @@ class Weights: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight - def _get_gptq_params(self) -> Tuple[int, int, int]: + def _get_gptq_params(self) -> Tuple[int, int, int, str]: try: bits = self.get_tensor("gptq_bits").item() groupsize = self.get_tensor("gptq_groupsize").item() desc_act = False + quant_method = "gptq" except (SafetensorError, RuntimeError) as e: try: bits = self.gptq_bits groupsize = self.gptq_groupsize desc_act = getattr(self, "gptq_desc_act", False) + quant_method = getattr(self, "quant_method", "gptq") except Exception: raise e - return bits, groupsize, desc_act + return bits, groupsize, desc_act, quant_method def _set_gptq_params(self, model_id, revision): filename = "config.json" @@ -351,6 +408,7 @@ class Weights: self.gptq_bits = data["quantization_config"]["bits"] self.gptq_groupsize = data["quantization_config"]["group_size"] self.gptq_desc_act = data["quantization_config"]["desc_act"] + self.quant_method = data["quantization_config"]["quant_method"] except Exception: filename = "quantize_config.json" try: @@ -365,6 +423,8 @@ class Weights: self.gptq_bits = data["bits"] self.gptq_groupsize = data["group_size"] self.gptq_desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + self.quant_method = "awq" except Exception: filename = "quant_config.json" try: @@ -379,5 +439,7 @@ class Weights: self.gptq_bits = data["w_bit"] self.gptq_groupsize = data["q_group_size"] self.gptq_desc_act = data["desc_act"] + if "version" in data and data["version"] == "GEMM": + self.quant_method = "awq" except Exception: pass