From d5b5bc750fd8d00920cb4cb5e5abe121457d717f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 21 Jul 2023 10:59:00 +0200 Subject: [PATCH] feat(server): Add exllama GPTQ CUDA kernel support #553 (#666) Just trying to get the integration tests to pass. # What does this PR do? Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --------- Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com> --- Dockerfile | 13 + Makefile | 3 + integration-tests/conftest.py | 11 +- .../test_flash_llama_gptq.json | 88 +++ .../test_flash_llama_gptq_all_params.json | 88 +++ .../test_flash_llama_gptq_load.json | 354 ++++++++++++ .../test_flash_starcoder_gptq.json | 193 +++++++ ...t_flash_starcoder_gptq_default_params.json | 193 +++++++ .../test_flash_starcoder_gptq_load.json | 534 ++++++++++++++++++ .../models/test_flash_llama_gptq.py | 57 ++ .../models/test_flash_starcoder_gptq.py | 49 ++ .../exllama_kernels/cuda_buffers.cu | 71 +++ .../exllama_kernels/cuda_buffers.cuh | 52 ++ .../exllama_kernels/cuda_compat.cuh | 58 ++ .../exllama_kernels/cuda_func/column_remap.cu | 61 ++ .../cuda_func/column_remap.cuh | 19 + .../exllama_kernels/cuda_func/q4_matmul.cu | 252 +++++++++ .../exllama_kernels/cuda_func/q4_matmul.cuh | 37 ++ .../exllama_kernels/cuda_func/q4_matrix.cu | 217 +++++++ .../exllama_kernels/cuda_func/q4_matrix.cuh | 53 ++ .../exllama_kernels/exllama_ext.cpp | 249 ++++++++ .../exllama_kernels/matrix.cuh | 294 ++++++++++ .../exllama_kernels/exllama_kernels/tuning.h | 13 + .../exllama_kernels/exllama_kernels/util.cuh | 29 + server/exllama_kernels/setup.py | 19 + .../custom_modeling/flash_llama_modeling.py | 1 - .../flash_santacoder_modeling.py | 23 +- .../models/flash_causal_lm.py | 1 - server/text_generation_server/models/model.py | 3 +- server/text_generation_server/server.py | 11 + .../utils/gptq/exllama.py | 121 ++++ server/text_generation_server/utils/layers.py | 32 +- .../text_generation_server/utils/weights.py | 96 +++- 33 files changed, 3232 insertions(+), 63 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json create mode 100644 integration-tests/models/test_flash_llama_gptq.py create mode 100644 integration-tests/models/test_flash_starcoder_gptq.py create mode 100644 server/exllama_kernels/exllama_kernels/cuda_buffers.cu create mode 100644 server/exllama_kernels/exllama_kernels/cuda_buffers.cuh create mode 100644 server/exllama_kernels/exllama_kernels/cuda_compat.cuh create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu create mode 100644 server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh create mode 100644 server/exllama_kernels/exllama_kernels/exllama_ext.cpp create mode 100644 server/exllama_kernels/exllama_kernels/matrix.cuh create mode 100644 server/exllama_kernels/exllama_kernels/tuning.h create mode 100644 server/exllama_kernels/exllama_kernels/util.cuh create mode 100644 server/exllama_kernels/setup.py create mode 100644 server/text_generation_server/utils/gptq/exllama.py diff --git a/Dockerfile b/Dockerfile index 168f2f97..70c60132 100644 --- a/Dockerfile +++ b/Dockerfile @@ -108,6 +108,17 @@ COPY server/Makefile-flash-att-v2 Makefile # Build specific version of flash attention v2 RUN make build-flash-attention-v2 +# Build Transformers exllama kernels +FROM kernel-builder as exllama-kernels-builder + +WORKDIR /usr/src + +COPY server/exllama_kernels/ . + + +# Build specific version of transformers +RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build + # Build Transformers CUDA kernels FROM kernel-builder as custom-kernels-builder @@ -161,6 +172,8 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86 # Copy build artifacts from custom kernels builder COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages +# Copy build artifacts from exllama kernels builder +COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages # Copy builds artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages diff --git a/Makefile b/Makefile index 3c2f2b9d..81b312d1 100644 --- a/Makefile +++ b/Makefile @@ -56,3 +56,6 @@ run-bloom: run-bloom-quantize: text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080 + +clean: + rm -rf target aml diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 8f59d75a..3f7a24dd 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -230,15 +230,16 @@ def launcher(event_loop): shard_uds_path, ] + env = os.environ + if num_shard is not None: args.extend(["--num-shard", str(num_shard)]) - if quantize: + if quantize is not None: args.append("--quantize") - args.append("bitsandbytes") + args.append(quantize) if trust_remote_code: args.append("--trust-remote-code") - env = os.environ env["LOG_LEVEL"] = "info,text_generation_router=debug" if not use_flash_attention: @@ -275,9 +276,9 @@ def launcher(event_loop): if num_shard is not None: args.extend(["--num-shard", str(num_shard)]) - if quantize: + if quantize is not None: args.append("--quantize") - args.append("bitsandbytes") + args.append(quantize) if trust_remote_code: args.append("--trust-remote-code") diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json new file mode 100644 index 00000000..e4ffb83b --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json @@ -0,0 +1,88 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.59375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6640625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 29918, + "logprob": -2.3867188, + "special": false, + "text": "_" + }, + { + "id": 5338, + "logprob": -2.8183594, + "special": false, + "text": "uri" + }, + { + "id": 13, + "logprob": -1.6367188, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.0527344, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.6542969, + "special": false, + "text": " request" + }, + { + "id": 29918, + "logprob": -0.056121826, + "special": false, + "text": "_" + }, + { + "id": 5338, + "logprob": -0.01600647, + "special": false, + "text": "uri" + }, + { + "id": 13, + "logprob": -0.87939453, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.7529297, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.2980957, + "special": false, + "text": " request" + } + ] + }, + "generated_text": "_uri\nTest request_uri\nTest request" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json new file mode 100644 index 00000000..02713a00 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json @@ -0,0 +1,88 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.6015625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6640625, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 29899, + "logprob": -1.1640625, + "special": false, + "text": "-" + }, + { + "id": 1454, + "logprob": -0.07543945, + "special": false, + "text": "for" + }, + { + "id": 29899, + "logprob": 0.0, + "special": false, + "text": "-" + }, + { + "id": 9342, + "logprob": 0.0, + "special": false, + "text": "comment" + }, + { + "id": 29901, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 396, + "logprob": -0.2956543, + "special": false, + "text": " #" + }, + { + "id": 29906, + "logprob": -0.52734375, + "special": false, + "text": "2" + }, + { + "id": 29900, + "logprob": -0.6899414, + "special": false, + "text": "0" + }, + { + "id": 29896, + "logprob": 0.0, + "special": false, + "text": "1" + }, + { + "id": 29946, + "logprob": -1.5068359, + "special": false, + "text": "4" + } + ] + }, + "generated_text": "Test request-for-comment: #2014" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json new file mode 100644 index 00000000..88bfa4f9 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json @@ -0,0 +1,354 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.6015625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.671875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 29918, + "logprob": -2.3828125, + "special": false, + "text": "_" + }, + { + "id": 5338, + "logprob": -2.8105469, + "special": false, + "text": "uri" + }, + { + "id": 13, + "logprob": -1.6396484, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.0546875, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.6513672, + "special": false, + "text": " request" + }, + { + "id": 29918, + "logprob": -0.056365967, + "special": false, + "text": "_" + }, + { + "id": 5338, + "logprob": -0.016082764, + "special": false, + "text": "uri" + }, + { + "id": 13, + "logprob": -0.87841797, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.7548828, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.29711914, + "special": false, + "text": " request" + } + ] + }, + "generated_text": "_uri\nTest request_uri\nTest request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.6015625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6640625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 29918, + "logprob": -2.3828125, + "special": false, + "text": "_" + }, + { + "id": 5338, + "logprob": -2.828125, + "special": false, + "text": "uri" + }, + { + "id": 13, + "logprob": -1.6386719, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.0527344, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.6542969, + "special": false, + "text": " request" + }, + { + "id": 29918, + "logprob": -0.055877686, + "special": false, + "text": "_" + }, + { + "id": 5338, + "logprob": -0.016021729, + "special": false, + "text": "uri" + }, + { + "id": 13, + "logprob": -0.8769531, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.7583008, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.29833984, + "special": false, + "text": " request" + } + ] + }, + "generated_text": "_uri\nTest request_uri\nTest request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.6015625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.671875, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 29918, + "logprob": -2.3847656, + "special": false, + "text": "_" + }, + { + "id": 5338, + "logprob": -2.8144531, + "special": false, + "text": "uri" + }, + { + "id": 13, + "logprob": -1.6396484, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.0527344, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.65478516, + "special": false, + "text": " request" + }, + { + "id": 29918, + "logprob": -0.056243896, + "special": false, + "text": "_" + }, + { + "id": 5338, + "logprob": -0.016143799, + "special": false, + "text": "uri" + }, + { + "id": 13, + "logprob": -0.8808594, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.75341797, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.2956543, + "special": false, + "text": " request" + } + ] + }, + "generated_text": "_uri\nTest request_uri\nTest request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.6015625, + "text": "Test" + }, + { + "id": 2009, + "logprob": -9.6640625, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 29918, + "logprob": -2.3769531, + "special": false, + "text": "_" + }, + { + "id": 5338, + "logprob": -2.8183594, + "special": false, + "text": "uri" + }, + { + "id": 13, + "logprob": -1.6396484, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -1.0546875, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.65478516, + "special": false, + "text": " request" + }, + { + "id": 29918, + "logprob": -0.05557251, + "special": false, + "text": "_" + }, + { + "id": 5338, + "logprob": -0.01612854, + "special": false, + "text": "uri" + }, + { + "id": 13, + "logprob": -0.8730469, + "special": false, + "text": "\n" + }, + { + "id": 3057, + "logprob": -0.7519531, + "special": false, + "text": "Test" + }, + { + "id": 2009, + "logprob": -0.29785156, + "special": false, + "text": " request" + } + ] + }, + "generated_text": "_uri\nTest request_uri\nTest request" + } +] diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json new file mode 100644 index 00000000..53055e42 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json @@ -0,0 +1,193 @@ +{ + "generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L", + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 20, + "seed": null, + "prefill": [ + { + "id": 589, + "text": "def", + "logprob": null + }, + { + "id": 3226, + "text": " ge", + "logprob": -9.0234375 + }, + { + "id": 21017, + "text": "ometric", + "logprob": -9.0859375 + }, + { + "id": 81, + "text": "_", + "logprob": -0.25878906 + }, + { + "id": 6009, + "text": "mean", + "logprob": -2.2109375 + }, + { + "id": 26, + "text": "(", + "logprob": -0.30371094 + }, + { + "id": 62, + "text": "L", + "logprob": -5.6054688 + }, + { + "id": 44, + "text": ":", + "logprob": -3.0722656 + }, + { + "id": 1682, + "text": " List", + "logprob": -0.6879883 + }, + { + "id": 77, + "text": "[", + "logprob": -0.38500977 + }, + { + "id": 1808, + "text": "float", + "logprob": -0.984375 + }, + { + "id": 10794, + "text": "]):", + "logprob": -2.5351562 + } + ], + "tokens": [ + { + "id": 284, + "text": "\n ", + "logprob": -1.1738281, + "special": false + }, + { + "id": 442, + "text": " return", + "logprob": -0.95947266, + "special": false + }, + { + "id": 3632, + "text": " sum", + "logprob": -1.4199219, + "special": false + }, + { + "id": 26, + "text": "(", + "logprob": -0.085876465, + "special": false + }, + { + "id": 62, + "text": "L", + "logprob": -0.09875488, + "special": false + }, + { + "id": 27, + "text": ")", + "logprob": -0.30517578, + "special": false + }, + { + "id": 517, + "text": " /", + "logprob": -0.42089844, + "special": false + }, + { + "id": 2069, + "text": " len", + "logprob": -0.042053223, + "special": false + }, + { + "id": 26, + "text": "(", + "logprob": -0.0011806488, + "special": false + }, + { + "id": 62, + "text": "L", + "logprob": -0.0005259514, + "special": false + }, + { + "id": 27, + "text": ")", + "logprob": -0.0017633438, + "special": false + }, + { + "id": 478, + "text": "\n\n", + "logprob": -0.69189453, + "special": false + }, + { + "id": 203, + "text": "\n", + "logprob": -0.041870117, + "special": false + }, + { + "id": 589, + "text": "def", + "logprob": -0.27856445, + "special": false + }, + { + "id": 3226, + "text": " ge", + "logprob": -1.7255859, + "special": false + }, + { + "id": 21017, + "text": "ometric", + "logprob": -0.011291504, + "special": false + }, + { + "id": 81, + "text": "_", + "logprob": -0.008430481, + "special": false + }, + { + "id": 6009, + "text": "mean", + "logprob": -0.025787354, + "special": false + }, + { + "id": 26, + "text": "(", + "logprob": -0.073913574, + "special": false + }, + { + "id": 62, + "text": "L", + "logprob": -0.09967041, + "special": false + } + ] + } +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json new file mode 100644 index 00000000..5598a2ad --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json @@ -0,0 +1,193 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 20, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 3226, + "logprob": -9.0234375, + "text": " ge" + }, + { + "id": 21017, + "logprob": -9.09375, + "text": "ometric" + }, + { + "id": 81, + "logprob": -0.25976562, + "text": "_" + }, + { + "id": 6009, + "logprob": -2.2148438, + "text": "mean" + }, + { + "id": 26, + "logprob": -0.3010254, + "text": "(" + }, + { + "id": 62, + "logprob": -5.6757812, + "text": "L" + }, + { + "id": 44, + "logprob": -3.0898438, + "text": ":" + }, + { + "id": 1682, + "logprob": -0.6791992, + "text": " List" + }, + { + "id": 77, + "logprob": -0.38891602, + "text": "[" + }, + { + "id": 1808, + "logprob": -0.92041016, + "text": "float" + }, + { + "id": 10794, + "logprob": -2.5390625, + "text": "]):" + } + ], + "seed": 0, + "tokens": [ + { + "id": 284, + "logprob": 0.0, + "special": false, + "text": "\n " + }, + { + "id": 442, + "logprob": 0.0, + "special": false, + "text": " return" + }, + { + "id": 11665, + "logprob": -1.6005859, + "special": false, + "text": " reduce" + }, + { + "id": 26, + "logprob": 0.0, + "special": false, + "text": "(" + }, + { + "id": 5962, + "logprob": 0.0, + "special": false, + "text": "lambda" + }, + { + "id": 816, + "logprob": 0.0, + "special": false, + "text": " x" + }, + { + "id": 30, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 533, + "logprob": 0.0, + "special": false, + "text": " y" + }, + { + "id": 44, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 816, + "logprob": 0.0, + "special": false, + "text": " x" + }, + { + "id": 319, + "logprob": 0.0, + "special": false, + "text": " *" + }, + { + "id": 533, + "logprob": 0.0, + "special": false, + "text": " y" + }, + { + "id": 30, + "logprob": 0.0, + "special": false, + "text": "," + }, + { + "id": 498, + "logprob": 0.0, + "special": false, + "text": " L" + }, + { + "id": 27, + "logprob": 0.0, + "special": false, + "text": ")" + }, + { + "id": 203, + "logprob": -0.11968994, + "special": false, + "text": "\n" + }, + { + "id": 203, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 589, + "logprob": 0.0, + "special": false, + "text": "def" + }, + { + "id": 3226, + "logprob": 0.0, + "special": false, + "text": " ge" + }, + { + "id": 21017, + "logprob": 0.0, + "special": false, + "text": "ometric" + } + ] + }, + "generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric" +} diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json new file mode 100644 index 00000000..5381ce5a --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json @@ -0,0 +1,534 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 3226, + "logprob": -9.0234375, + "text": " ge" + }, + { + "id": 21017, + "logprob": -9.0859375, + "text": "ometric" + }, + { + "id": 81, + "logprob": -0.25927734, + "text": "_" + }, + { + "id": 6009, + "logprob": -2.25, + "text": "mean" + }, + { + "id": 26, + "logprob": -0.30126953, + "text": "(" + }, + { + "id": 62, + "logprob": -5.7539062, + "text": "L" + }, + { + "id": 44, + "logprob": -3.0878906, + "text": ":" + }, + { + "id": 1682, + "logprob": -0.6845703, + "text": " List" + }, + { + "id": 77, + "logprob": -0.3918457, + "text": "[" + }, + { + "id": 1808, + "logprob": -0.8798828, + "text": "float" + }, + { + "id": 10794, + "logprob": -2.4980469, + "text": "]):" + } + ], + "seed": null, + "tokens": [ + { + "id": 284, + "logprob": -1.1533203, + "special": false, + "text": "\n " + }, + { + "id": 442, + "logprob": -0.91796875, + "special": false, + "text": " return" + }, + { + "id": 3632, + "logprob": -1.3291016, + "special": false, + "text": " sum" + }, + { + "id": 26, + "logprob": -0.08062744, + "special": false, + "text": "(" + }, + { + "id": 62, + "logprob": -0.097717285, + "special": false, + "text": "L" + }, + { + "id": 27, + "logprob": -0.29003906, + "special": false, + "text": ")" + }, + { + "id": 517, + "logprob": -0.34958984, + "special": false, + "text": " /" + }, + { + "id": 2069, + "logprob": -0.03829956, + "special": false, + "text": " len" + }, + { + "id": 26, + "logprob": -0.0011987686, + "special": false, + "text": "(" + }, + { + "id": 62, + "logprob": -0.00050878525, + "special": false, + "text": "L" + } + ] + }, + "generated_text": "\n return sum(L) / len(L" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 3226, + "logprob": -9.0234375, + "text": " ge" + }, + { + "id": 21017, + "logprob": -9.0859375, + "text": "ometric" + }, + { + "id": 81, + "logprob": -0.25878906, + "text": "_" + }, + { + "id": 6009, + "logprob": -2.2109375, + "text": "mean" + }, + { + "id": 26, + "logprob": -0.30371094, + "text": "(" + }, + { + "id": 62, + "logprob": -5.6054688, + "text": "L" + }, + { + "id": 44, + "logprob": -3.0722656, + "text": ":" + }, + { + "id": 1682, + "logprob": -0.6879883, + "text": " List" + }, + { + "id": 77, + "logprob": -0.38500977, + "text": "[" + }, + { + "id": 1808, + "logprob": -0.984375, + "text": "float" + }, + { + "id": 10794, + "logprob": -2.5351562, + "text": "]):" + } + ], + "seed": null, + "tokens": [ + { + "id": 284, + "logprob": -1.1738281, + "special": false, + "text": "\n " + }, + { + "id": 442, + "logprob": -0.9584961, + "special": false, + "text": " return" + }, + { + "id": 3632, + "logprob": -1.4169922, + "special": false, + "text": " sum" + }, + { + "id": 26, + "logprob": -0.085876465, + "special": false, + "text": "(" + }, + { + "id": 62, + "logprob": -0.0982666, + "special": false, + "text": "L" + }, + { + "id": 27, + "logprob": -0.3022461, + "special": false, + "text": ")" + }, + { + "id": 517, + "logprob": -0.40504883, + "special": false, + "text": " /" + }, + { + "id": 2069, + "logprob": -0.041656494, + "special": false, + "text": " len" + }, + { + "id": 26, + "logprob": -0.0011844635, + "special": false, + "text": "(" + }, + { + "id": 62, + "logprob": -0.0005264282, + "special": false, + "text": "L" + } + ] + }, + "generated_text": "\n return sum(L) / len(L" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 3226, + "logprob": -9.0234375, + "text": " ge" + }, + { + "id": 21017, + "logprob": -9.0859375, + "text": "ometric" + }, + { + "id": 81, + "logprob": -0.25927734, + "text": "_" + }, + { + "id": 6009, + "logprob": -2.25, + "text": "mean" + }, + { + "id": 26, + "logprob": -0.30126953, + "text": "(" + }, + { + "id": 62, + "logprob": -5.7539062, + "text": "L" + }, + { + "id": 44, + "logprob": -3.0878906, + "text": ":" + }, + { + "id": 1682, + "logprob": -0.6845703, + "text": " List" + }, + { + "id": 77, + "logprob": -0.3918457, + "text": "[" + }, + { + "id": 1808, + "logprob": -0.8798828, + "text": "float" + }, + { + "id": 10794, + "logprob": -2.4980469, + "text": "]):" + } + ], + "seed": null, + "tokens": [ + { + "id": 284, + "logprob": -1.1533203, + "special": false, + "text": "\n " + }, + { + "id": 442, + "logprob": -0.9165039, + "special": false, + "text": " return" + }, + { + "id": 3632, + "logprob": -1.328125, + "special": false, + "text": " sum" + }, + { + "id": 26, + "logprob": -0.07946777, + "special": false, + "text": "(" + }, + { + "id": 62, + "logprob": -0.09820557, + "special": false, + "text": "L" + }, + { + "id": 27, + "logprob": -0.28930664, + "special": false, + "text": ")" + }, + { + "id": 517, + "logprob": -0.34592773, + "special": false, + "text": " /" + }, + { + "id": 2069, + "logprob": -0.038330078, + "special": false, + "text": " len" + }, + { + "id": 26, + "logprob": -0.0011940002, + "special": false, + "text": "(" + }, + { + "id": 62, + "logprob": -0.00050878525, + "special": false, + "text": "L" + } + ] + }, + "generated_text": "\n return sum(L) / len(L" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 589, + "logprob": null, + "text": "def" + }, + { + "id": 3226, + "logprob": -9.0234375, + "text": " ge" + }, + { + "id": 21017, + "logprob": -9.0859375, + "text": "ometric" + }, + { + "id": 81, + "logprob": -0.25927734, + "text": "_" + }, + { + "id": 6009, + "logprob": -2.25, + "text": "mean" + }, + { + "id": 26, + "logprob": -0.30126953, + "text": "(" + }, + { + "id": 62, + "logprob": -5.7539062, + "text": "L" + }, + { + "id": 44, + "logprob": -3.0878906, + "text": ":" + }, + { + "id": 1682, + "logprob": -0.6845703, + "text": " List" + }, + { + "id": 77, + "logprob": -0.3918457, + "text": "[" + }, + { + "id": 1808, + "logprob": -0.8798828, + "text": "float" + }, + { + "id": 10794, + "logprob": -2.4980469, + "text": "]):" + } + ], + "seed": null, + "tokens": [ + { + "id": 284, + "logprob": -1.1533203, + "special": false, + "text": "\n " + }, + { + "id": 442, + "logprob": -0.91259766, + "special": false, + "text": " return" + }, + { + "id": 3632, + "logprob": -1.3251953, + "special": false, + "text": " sum" + }, + { + "id": 26, + "logprob": -0.08062744, + "special": false, + "text": "(" + }, + { + "id": 62, + "logprob": -0.09906006, + "special": false, + "text": "L" + }, + { + "id": 27, + "logprob": -0.28979492, + "special": false, + "text": ")" + }, + { + "id": 517, + "logprob": -0.35958984, + "special": false, + "text": " /" + }, + { + "id": 2069, + "logprob": -0.038604736, + "special": false, + "text": " len" + }, + { + "id": 26, + "logprob": -0.0011901855, + "special": false, + "text": "(" + }, + { + "id": 62, + "logprob": -0.0005078316, + "special": false, + "text": "L" + } + ] + }, + "generated_text": "\n return sum(L) / len(L" + } +] diff --git a/integration-tests/models/test_flash_llama_gptq.py b/integration-tests/models/test_flash_llama_gptq.py new file mode 100644 index 00000000..bc525f6d --- /dev/null +++ b/integration-tests/models/test_flash_llama_gptq.py @@ -0,0 +1,57 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_gptq_handle(launcher): + with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_gptq(flash_llama_gptq_handle): + await flash_llama_gptq_handle.health(300) + return flash_llama_gptq_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): + response = await flash_llama_gptq.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): + response = await flash_llama_gptq.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_gptq_load(flash_llama_gptq, generate_load, response_snapshot): + responses = await generate_load(flash_llama_gptq, "Test request", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py new file mode 100644 index 00000000..b6bed6a6 --- /dev/null +++ b/integration-tests/models/test_flash_starcoder_gptq.py @@ -0,0 +1,49 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_starcoder_gptq_handle(launcher): + with launcher("Narsil/starcoder-gptq", num_shard=2, quantize="gptq") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_starcoder_gptq(flash_starcoder_gptq_handle): + await flash_starcoder_gptq_handle.health(300) + return flash_starcoder_gptq_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot): + response = await flash_starcoder_gptq.generate( + "def geometric_mean(L: List[float]):", max_new_tokens=20, decoder_input_details=True, + ) + assert response.details.generated_tokens == 20 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, response_snapshot): + response = await flash_starcoder_gptq.generate( + "def geometric_mean(L: List[float]):", + max_new_tokens=20, + temperature=0.2, + top_p=0.95, + decoder_input_details=True, + seed=0, + ) + assert response.details.generated_tokens == 20 + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, response_snapshot): + responses = await generate_load(flash_starcoder_gptq, "def geometric_mean(L: List[float]):", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot \ No newline at end of file diff --git a/server/exllama_kernels/exllama_kernels/cuda_buffers.cu b/server/exllama_kernels/exllama_kernels/cuda_buffers.cu new file mode 100644 index 00000000..ee2cbee2 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_buffers.cu @@ -0,0 +1,71 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#define _cuda_buffers_cu +#include "cuda_buffers.cuh" + +CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; +// __constant__ half2 q4_table[16][256]; +// half2 q4_table_host[16][256]; +// bool q4_table_init = false; + +CudaBuffers::CudaBuffers +( + int _device, + half* _temp_state, + half* _temp_dq +) : + device(_device), + temp_state(_temp_state), + temp_dq(_temp_dq) +{ + cudaSetDevice(_device); + + cudaStreamCreate(&alt_stream_1); + cudaStreamCreate(&alt_stream_2); + cudaStreamCreate(&alt_stream_3); + cudaEventCreate(&alt_stream_1_done); + cudaEventCreate(&alt_stream_2_done); + cudaEventCreate(&alt_stream_3_done); +} + +CudaBuffers::~CudaBuffers() +{ + cudaStreamDestroy(alt_stream_1); + cudaStreamDestroy(alt_stream_2); + cudaStreamDestroy(alt_stream_3); + cudaEventDestroy(alt_stream_1_done); + cudaEventDestroy(alt_stream_2_done); + cudaEventDestroy(alt_stream_3_done); +} + +CudaBuffers* get_buffers(const int device_index) +{ + return g_buffers[device_index]; +} + +void prepare_buffers_cuda +( + int _device, + half* _temp_state, + half* _temp_dq +) +{ + CudaBuffers* buffers = new CudaBuffers + ( + _device, + _temp_state, + _temp_dq + ); + + g_buffers[_device] = buffers; +} + +void cleanup_buffers_cuda() +{ + for (int i = 0; i < CUDA_MAX_DEVICES; i++) + { + if (!g_buffers[i]) continue; + delete g_buffers[i]; + g_buffers[i] = NULL; + } +} diff --git a/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh b/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh new file mode 100644 index 00000000..afb60a01 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_buffers.cuh @@ -0,0 +1,52 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_buffers_cuh +#define _cuda_buffers_cuh + +#include +#include +#include +#include + +const int CUDA_MAX_DEVICES = 16; + +// #ifndef _cuda_buffers_cu +// extern __constant__ half2 q4_table[16][256]; +// #endif + +class CudaBuffers +{ +public: + int device; + + half* temp_state; // [max_hidden_rows * intermediate_size] + half* temp_dq; // size of largest quant tensor * 8 + + cudaStream_t alt_stream_1; + cudaStream_t alt_stream_2; + cudaStream_t alt_stream_3; + cudaEvent_t alt_stream_1_done; + cudaEvent_t alt_stream_2_done; + cudaEvent_t alt_stream_3_done; + + CudaBuffers + ( + int _device, + half* _temp_state, + half* _temp_dq + ); + ~CudaBuffers(); +}; + +CudaBuffers* get_buffers(const int device_index); + +void prepare_buffers_cuda +( + int _device, + half* _temp_state, + half* _temp_dq +); + +void cleanup_buffers_cuda(); + +#endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_compat.cuh b/server/exllama_kernels/exllama_kernels/cuda_compat.cuh new file mode 100644 index 00000000..8dfa25de --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_compat.cuh @@ -0,0 +1,58 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _cuda_compat_cuh +#define _cuda_compat_cuh + +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) +#if __CUDA_ARCH__ < 700 + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +#endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu new file mode 100644 index 00000000..c25b0206 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu @@ -0,0 +1,61 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "column_remap.cuh" +#include "../util.cuh" + +const int SHUF_BLOCKSIZE_X = 256; +const int SHUF_BLOCKSIZE_Y = 16; + +__global__ void column_remap_kernel +( + const half* __restrict__ x, + half* __restrict__ x_new, + const int x_width, + const int x_height, + const uint32_t* x_map +) +{ + int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; + + int x_stride = x_width; + int x_idx = x_row * x_stride + x_column; + + int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); + int x_idx_end = x_row_end * x_stride + x_column; + + int s_column = x_map[x_column]; + int s_idx = x_row * x_stride + s_column; + + while (x_idx < x_idx_end) + { + x_new[x_idx] = x[s_idx]; + x_idx += x_stride; + s_idx += x_stride; + } +} + +// Remap columns in x to correspond to sequential group index before matmul +// +// perform x -> seq_x such that seq_x @ seq_w == x @ w + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +) +{ + dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); + + dim3 blocks + ( + (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, + (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, + 1 + ); + + column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); +} diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh new file mode 100644 index 00000000..6571c17d --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh @@ -0,0 +1,19 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _column_remap_cuh +#define _column_remap_cuh + +#include +#include +#include + +void column_remap_cuda +( + const half* x, + half* x_new, + const int x_height, + const int x_width, + const uint32_t* x_map +); + +#endif \ No newline at end of file diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu new file mode 100644 index 00000000..60dc4c9d --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu @@ -0,0 +1,252 @@ +#include "q4_matmul.cuh" +#include "column_remap.cuh" +#include "../util.cuh" +#include "../matrix.cuh" +#include "../cuda_compat.cuh" +#include "../cuda_buffers.cuh" + +const int THREADS_X = 32; // Block size and thread count along columns in w and out +const int THREADS_Y = 1; // Block size and thread count along rows in x and out + +typedef void (*fp_q4_matmul_kernel) +( + const half*, + const uint32_t*, + half*, + const half*, + const uint32_t*, + const int, + const int, + const int, + const int, + const int, + const uint32_t*, + bool +); + +template +__global__ void q4_matmul_kernel +( + const half* __restrict__ x, + const uint32_t* __restrict__ w, + half* __restrict__ out, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int dim, + const int width, + const int groupsize, + const int block_size_z, + const uint32_t* __restrict__ x_map, + bool no_zero +) +{ + // Start of block + + int x_column = block_size_z * blockIdx.z; + int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); + + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + int x_row = THREADS_Y * blockIdx.y + threadIdx.y; + + int iterations = (x_column_end - x_column) / 8; + + // Views + + MatrixView_half x_(x, height, dim); + MatrixView_half w_scales_(w_scales, dim / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); + MatrixView_q4_column w_(w, dim, width); + MatrixView_half_rw out_(out, height, width); + + // Zero output + + if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) + { + *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; + __syncthreads(); + } + + // Loop over part of x row (and w column) + + half2 acc = {}; + half acc_h = {}; + + if constexpr (use_groupsize) + { + // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this + // could be slightly faster + + for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) + { + if constexpr (use_half2) + { + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + else + { + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + } + } + } + else + { + // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache + + for (int k = x_column; k < x_column + iterations * 8; k += 8) + { + if constexpr (use_half2) + { + int group = k / groupsize; + half2 w_scale = w_scales_.item_half2half2(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + else + { + int group = k / groupsize; + half w_scale = w_scales_.item(group, w_column); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + } + } + } + + // Add to block result + + if constexpr (use_half2) + { + half result = __hadd(acc.x, acc.y); + atomicAdd(out_.item_ptr(x_row, w_column), result); + } + else + { + atomicAdd(out_.item_ptr(x_row, w_column), acc_h); + } +} + +fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) +{ + // + if (tuningParams->matmul_no_half2) { + if (block_size_z % groupsize == 0) { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } else { + if (block_size_z % groupsize == 0) + { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } else { + if (x_map) return q4_matmul_kernel; + else return q4_matmul_kernel; + } + } +}; + +// Compute y = x @ w + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero, + cudaStream_t alt_stream +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + + uint32_t* x_map = w->cuda_x_map; + const half* x_mapped = x; + if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) + { + CudaBuffers* buffers = get_buffers(w->device); + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + x_map = NULL; + } + + int block_size_z; + if (w->width == 4096) block_size_z = 384; // 7B + else if (w->width == 11008) block_size_z = 256; + else if (w->width == 5120) block_size_z = 384; // 13B + else if (w->width == 13824) block_size_z = 256; + else if (w->width == 6656) block_size_z = 256; // 33B + else if (w->width == 17920) block_size_z = 128; + else block_size_z = 256; + + //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); + + dim3 threads(THREADS_X, THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height + threads.y - 1) / threads.y, + (dim + block_size_z - 1) / block_size_z + ); + + fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); + + kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); +} + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + const cublasHandle_t handle, + bool no_zero +) +{ + int height = x_height; + int dim = w->height; + int width = w->width; + + cudaSetDevice(w->device); + CudaBuffers* buffers = get_buffers(w->device); + + const half* x_mapped = x; + if (w->cuda_x_map) + { + column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); + x_mapped = buffers->temp_state; + } + + w->reconstruct(buffers->temp_dq); + + const half alpha = __float2half(1.0f); + const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); + cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); + +// const float alpha = 1.0f; +// const float beta = no_zero ? 1.0f : 0.0f; +// cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, +// x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); +} diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh new file mode 100644 index 00000000..63611790 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh @@ -0,0 +1,37 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matmul_cuh +#define _q4_matmul_cuh + +#include +#include +#include +#include +#include + +#include "q4_matrix.cuh" +#include "../tuning.h" + +void q4_matmul_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + const Q4Matrix* w, + half* out, + bool no_zero = false, + cudaStream_t alt_stream = NULL +); + +void q4_matmul_recons_cuda +( + ExLlamaTuning* tuningParams, + const half* x, + const int x_height, + Q4Matrix* w, + half* out, + const cublasHandle_t handle, + bool no_zero = false +); + +#endif diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu new file mode 100644 index 00000000..f3d1564f --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu @@ -0,0 +1,217 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include "q4_matrix.cuh" +#include +#include "../util.cuh" +#include "../matrix.cuh" + +using namespace std; + +const int UNSHUF_BLOCKSIZE_X = 64; + +const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column +const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows + +vector g_q4_matrices; + +void g_q4_keep_matrix(Q4Matrix* m) +{ + g_q4_matrices.push_back(m); +} + +void g_q4_free_matrices() +{ + for (const auto& m : g_q4_matrices) delete m; + g_q4_matrices.clear(); +} + +Q4Matrix::Q4Matrix +( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device +) : + height(_height), + width(_width), + groups(_groups), + device(_device) +{ + cudaSetDevice(device); + + cuda_qweight = _qweight; + cuda_qzeros = _qzeros; + cuda_scales = _scales; + + groupsize = height / groups; + + if (_g_idx) make_sequential(_g_idx); +} + +Q4Matrix::~Q4Matrix() +{ +} + +// Make sequential + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint32_t* __restrict__ x_map, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; + int w_new2_row = blockIdx.y; + + int x_map_idx = w_new2_row << 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = x_map[x_map_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) +{ + uint32_t* cuda_new_qweight = NULL; + cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Move to CUDA + + cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); + dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1); + + make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); + + // Replace qweights + + cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); +} + +__global__ void reconstruct_kernel +( + const uint32_t* __restrict__ w, + half* __restrict__ out, // (y) + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int height, + const int width, + const int groupsize +) +{ + // Start of block + + int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; + int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; + + // Views + + MatrixView_q4_column w_(w, height, width); + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, height / groupsize, width); + MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); + + // Groupsize version + + int group = row / groupsize; + + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + + uint32_t w_read = w_.item_uint32_t(row, column); + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int s = 0; s < 32; s += 4) + { + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + *out_ptr = w_item; out_ptr += out_.width; + } +} + +void Q4Matrix::reconstruct(half* out) +{ + dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); + + dim3 blocks + ( + (width + threads.x - 1) / threads.x, + (height / 8 + threads.y - 1) / threads.y, + 1 + ); + + reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); +} \ No newline at end of file diff --git a/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh new file mode 100644 index 00000000..50cb72a4 --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh @@ -0,0 +1,53 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matrix_cuh +#define _q4_matrix_cuh + +#include +#include +#include + +class Q4Matrix +{ +public: + + int device; + + int height; + int width; + int groups; + int groupsize; + + uint32_t* cuda_qweight = NULL; + uint32_t* cuda_qzeros = NULL; + half* cuda_scales = NULL; + uint32_t* cuda_x_map = NULL; + + Q4Matrix + ( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device + ); + + ~Q4Matrix(); + + void reconstruct(half* out); + +private: + + void make_sequential(const uint32_t* cpu_g_idx); + +}; + +void g_q4_keep_matrix(Q4Matrix* m); +void g_q4_free_matrices(); + +#endif \ No newline at end of file diff --git a/server/exllama_kernels/exllama_kernels/exllama_ext.cpp b/server/exllama_kernels/exllama_kernels/exllama_ext.cpp new file mode 100644 index 00000000..b786988b --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/exllama_ext.cpp @@ -0,0 +1,249 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include +#include +#include +#include +#include +#include +#include "util.cuh" +#include "tuning.h" +#include "cuda_buffers.cuh" +#include "cuda_func/q4_matrix.cuh" +#include "cuda_func/q4_matmul.cuh" +#include "cuda_func/column_remap.cuh" + +// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a +// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of +// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. + +void check_cuda(cudaError_t ret) +{ + switch (ret) + { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); \ + printf(" **** %s\n", cudaGetErrorString(ret)); \ + TORCH_CHECK(false, "CUDA error"); \ + break; + } +} + +// Some decluttering macros + +#define STRINGIFY_(__x) #__x +#define STRINGIFY(__x) STRINGIFY_(__x) +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ +do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ + TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ +} while(0) + +#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ +do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ +} while(0) + +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) +{ + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; +} + + +// Tuning parameters + +ExLlamaTuning tuningParams; + +void set_tuning_params +( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2 +) +{ + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; +} + + +// Release all unmanaged objects allocated by the extension + +void cleanup() +{ + cleanup_buffers_cuda(); + g_q4_free_matrices(); +} + + +// Prepare buffers for forward pass + +void prepare_buffers +( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq +) +{ + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + + prepare_buffers_cuda + ( + device_index, + (half*) temp_state.data_ptr(), + (half*) temp_dq.data_ptr() + ); +} + + +// Create Q4Matrix, return handle + +uintptr_t make_q4 +( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device +) +{ + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + + int width = qweight.size(1); + int height = qweight.size(0) * 8; + int groups = qzeros.size(0); + + Q4Matrix* m = new Q4Matrix + ( + height, + width, + groups, + + (uint32_t*) qweight.data_ptr(), + (uint32_t*) qzeros.data_ptr(), + (half*) scales.data_ptr(), + g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), + + device + ); + + g_q4_keep_matrix(m); + return reinterpret_cast (m); +} + + +// Matmul half @ quant -> half + +void q4_matmul +( + torch::Tensor x, + uintptr_t w, + torch::Tensor out +) +{ + Q4Matrix* wm = reinterpret_cast (w); + + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) + { + q4_matmul_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr() + ); + } + else + { + q4_matmul_recons_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + at::cuda::getCurrentCUDABlasHandle() + ); + } +} + + +// Remap columns in half tensor + +void column_remap +( + torch::Tensor x, + torch::Tensor x_new, + torch::Tensor x_map +) +{ + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda + ( + (half*) x.data_ptr(), + (half*) x_new.data_ptr(), + height, + width, + (uint32_t*) x_map.data_ptr() + ); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); +} diff --git a/server/exllama_kernels/exllama_kernels/matrix.cuh b/server/exllama_kernels/exllama_kernels/matrix.cuh new file mode 100644 index 00000000..2fd5ab0b --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/matrix.cuh @@ -0,0 +1,294 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _matrix_cuh +#define _matrix_cuh + +#include +#include + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } +}; + +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +}; + +// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale + +__device__ __forceinline__ half2 dot_product_8 +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + +// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) +// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; +// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; +// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; + + half2 tmp = __hmul2(*h_ptr++, v_01); + tmp = __hfma2(*h_ptr++, v_23, tmp); + tmp = __hfma2(*h_ptr++, v_45, tmp); + tmp = __hfma2(*h_ptr++, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) + const int count +) +{ + const half* h_ptr = h_.item_ptr(h_row, h_column); + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(*h_ptr++, v_0); + tmp = __hfma(*h_ptr++, v_1, tmp); + tmp = __hfma(*h_ptr++, v_2, tmp); + tmp = __hfma(*h_ptr++, v_3, tmp); + tmp = __hfma(*h_ptr++, v_4, tmp); + tmp = __hfma(*h_ptr++, v_5, tmp); + tmp = __hfma(*h_ptr++, v_6, tmp); + tmp = __hfma(*h_ptr++, v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map + +__device__ __forceinline__ half2 dot_product_8_x_map +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + + half h_0 = h_ptr[*x_map_ptr++]; + half h_1 = h_ptr[*x_map_ptr++]; + half h_2 = h_ptr[*x_map_ptr++]; + half h_3 = h_ptr[*x_map_ptr++]; + half h_4 = h_ptr[*x_map_ptr++]; + half h_5 = h_ptr[*x_map_ptr++]; + half h_6 = h_ptr[*x_map_ptr++]; + half h_7 = h_ptr[*x_map_ptr++]; + + half2 h_01 = __halves2half2(h_0, h_1); + half2 h_23 = __halves2half2(h_2, h_3); + half2 h_45 = __halves2half2(h_4, h_5); + half2 h_67 = __halves2half2(h_6, h_7); + + half2 tmp = __hmul2(h_01, v_01); + tmp = __hfma2(h_23, v_23, tmp); + tmp = __hfma2(h_45, v_45, tmp); + tmp = __hfma2(h_67, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_x_map_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); + tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + +#endif diff --git a/server/exllama_kernels/exllama_kernels/tuning.h b/server/exllama_kernels/exllama_kernels/tuning.h new file mode 100644 index 00000000..770ca46a --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/tuning.h @@ -0,0 +1,13 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _tuning_h +#define _tuning_h + +struct ExLlamaTuning +{ + int matmul_recons_thd; + bool matmul_fused_remap; + bool matmul_no_half2; +}; + +#endif diff --git a/server/exllama_kernels/exllama_kernels/util.cuh b/server/exllama_kernels/exllama_kernels/util.cuh new file mode 100644 index 00000000..2839b10f --- /dev/null +++ b/server/exllama_kernels/exllama_kernels/util.cuh @@ -0,0 +1,29 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include + +#define cudaUnspecified cudaErrorApiFailureBase + +// React to failure on return code != cudaSuccess + +#define _cuda_check(fn) \ +do { \ + {_cuda_err = fn;} \ + if (_cuda_err != cudaSuccess) goto _cuda_fail; \ +} while(false) + +// React to failure on return code == 0 + +#define _alloc_check(fn) \ +do { \ + if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ + else _cuda_err = cudaSuccess; \ +} while(false) + +#endif diff --git a/server/exllama_kernels/setup.py b/server/exllama_kernels/setup.py new file mode 100644 index 00000000..f06a72bd --- /dev/null +++ b/server/exllama_kernels/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name="exllama_kernels", + ext_modules=[ + CUDAExtension( + name="exllama_kernels", + sources=[ + "exllama_kernels/exllama_ext.cpp", + "exllama_kernels/cuda_buffers.cu", + "exllama_kernels/cuda_func/column_remap.cu", + "exllama_kernels/cuda_func/q4_matmul.cu", + "exllama_kernels/cuda_func/q4_matrix.cu" + ], + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 702bcf61..b6285856 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -383,7 +383,6 @@ class FlashLlamaLayer(nn.Module): class FlashLlamaModel(torch.nn.Module): def __init__(self, config, weights): super().__init__() - self.config = config process_group = weights.process_group self.tp_rank = process_group.rank() diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 6f5c60fc..04bd422f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -20,7 +20,6 @@ from text_generation_server.utils.layers import ( ) from safetensors import SafetensorError - def load_multi_mqa( config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size ): @@ -50,6 +49,7 @@ def _load_multi_mqa_gptq( q_tensor = slice_[:, start:stop] kv_tensor = slice_[:, -2 * head_size :] qweight = torch.cat([q_tensor, kv_tensor], dim=1) + qweight = qweight.to(device=weights.device) slice_ = weights._get_slice(f"{prefix}.c_attn.scales") shape = slice_.get_shape() @@ -60,6 +60,7 @@ def _load_multi_mqa_gptq( q_tensor = slice_[:, start:stop] kv_tensor = slice_[:, -2 * head_size :] scales = torch.cat([q_tensor, kv_tensor], dim=1) + scales = scales.to(device=weights.device) slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros") shape = slice_.get_shape() @@ -70,21 +71,15 @@ def _load_multi_mqa_gptq( q_tensor = slice_[:, start:stop] kv_tensor = slice_[:, -2 * head_size * 4 // 32 :] qzeros = torch.cat([q_tensor, kv_tensor], dim=1) + qzeros = qzeros.to(device=weights.device) g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") - try: - bits = weights.get_tensor("gptq_bits").item() - groupsize = weights.get_tensor("gptq_groupsize").item() - except SafetensorError as e: - try: - import os + g_idx = g_idx.to(device=weights.device) + bits, groupsize = weights._get_gptq_qparams() - bits = int(os.getenv("GPTQ_BITS")) - groupsize = int(os.getenv("GPTQ_GROUPSIZE")) - except Exception: - raise e - - weight = (qweight, qzeros, scales, g_idx, bits, groupsize) + from text_generation_server.utils.layers import HAS_EXLLAMA + use_exllama = HAS_EXLLAMA + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) if bias: slice_ = weights._get_slice(f"{prefix}.c_attn.bias") @@ -97,6 +92,7 @@ def _load_multi_mqa_gptq( q_tensor = slice_[start:stop] kv_tensor = slice_[-2 * head_size :] bias = torch.cat([q_tensor, kv_tensor], dim=0) + bias = bias.to(device=weights.device) return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) else: @@ -361,7 +357,6 @@ class Block(nn.Module): max_s, ): hidden_states, residual = self.ln_1(hidden_states, residual) - hidden_states = self.attn( hidden_states, cu_seqlen_prefill, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index c0592cb0..547678a8 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -6,7 +6,6 @@ import torch.distributed import numpy as np from dataclasses import dataclass -from loguru import logger from opentelemetry import trace from transformers import PreTrainedTokenizerBase from typing import Optional, Tuple, List, Type, Union, Dict diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 3827197f..89e6e99b 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -3,14 +3,13 @@ import torch from abc import ABC, abstractmethod from typing import List, Tuple, Optional, TypeVar, Type -from transformers import PreTrainedTokenizerBase +from transformers import PreTrainedTokenizerBase, PretrainedConfig from text_generation_server.models.types import Batch, GeneratedText from text_generation_server.pb.generate_pb2 import InfoResponse B = TypeVar("B", bound=Batch) - class Model(ABC): def __init__( self, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index e0efbcf5..b279426b 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor + class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def __init__(self, model: Model, cache: Cache, server_urls: List[str]): self.cache = cache @@ -140,6 +141,16 @@ def serve( logger.exception("Error when initializing model") raise + if quantize == "gptq": + try: + # When using GPTQ, Exllama kernels need some global kernels + # For which we have the finale shapes only after the model has loaded + # This will allocate those buffers. + from text_generation_server.utils.gptq.exllama import create_exllama_buffers + create_exllama_buffers() + except ImportError: + pass + server = aio.server( interceptors=[ ExceptionInterceptor(), diff --git a/server/text_generation_server/utils/gptq/exllama.py b/server/text_generation_server/utils/gptq/exllama.py new file mode 100644 index 00000000..aba66796 --- /dev/null +++ b/server/text_generation_server/utils/gptq/exllama.py @@ -0,0 +1,121 @@ + +import torch +from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params + +# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension +none_tensor = torch.empty((1, 1), device = "meta") + +def ext_make_q4(qweight, qzeros, scales, g_idx, device): + """Construct Q4Matrix, return handle""" + return make_q4(qweight, + qzeros, + scales, + g_idx if g_idx is not None else none_tensor, + device) + +def ext_q4_matmul(x, q4, q4_width): + """Matrix multiplication, returns x @ q4""" + outshape = x.shape[:-1] + (q4_width,) + x = x.view(-1, x.shape[-1]) + output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device) + + q4_matmul(x, q4, output) + + return output.view(outshape) + +MAX_DQ = 1 +MAX_INNER = 1 +ACT_ORDER = False +DEVICE = None + +TEMP_STATE = None +TEMP_DQ = None + +def create_exllama_buffers(): + global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ + + if ACT_ORDER: + # TODO: this should be set to rust side `max_total_tokens`, but TGI + # does not offer an API to expose this variable to python, as this variable + # is handled by the client but it appears the model is initialized by the server. + # An alternative could be to initialize the buffers during warmup. + # Dummy + max_total_tokens = 2048 + else: + max_total_tokens = 1 + + # This temp_state buffer is required to reorder X in the act-order case. + temp_state = torch.zeros((max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE) + temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE) + + # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. + prepare_buffers(DEVICE, temp_state, temp_dq) + + matmul_recons_thd = 8 + matmul_fused_remap = False + matmul_no_half2 = False + set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) + + TEMP_STATE, TEMP_DQ = temp_state, temp_dq + +class Ex4bitLinear: + """Linear layer implementation with per-group 4-bit quantization of the weights""" + def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): + global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE + assert bits == 4 + + self.device = qweight.device + self.qweight = qweight + self.qzeros = qzeros + self.scales = scales + self.g_idx = g_idx.cpu() if g_idx is not None else None + self.bias = bias if bias is not None else None + + if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))): + self.empty_g_idx = True + self.g_idx = None + + assert self.device.type == "cuda" + assert self.device.index is not None + + self.q4 = ext_make_q4( + self.qweight, + self.qzeros, + self.scales, + self.g_idx, + self.device.index + ) + + self.height = qweight.shape[0] * 8 + self.width = qweight.shape[1] + + # Infer groupsize from height of qzeros + self.groupsize = None + if self.qzeros.shape[0] > 1: + self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0]) + + if self.groupsize is not None: + assert groupsize == self.groupsize + + # Handle act-order matrix + if self.g_idx is not None: + if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?") + self.act_order = True + else: + self.act_order = False + + DEVICE = self.qweight.device + + MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8) + + if self.act_order: + MAX_INNER = max(MAX_INNER, self.height, self.width) + + ACT_ORDER = True + + def forward(self, x): + out = ext_q4_matmul(x, self.q4, self.width) + + if self.bias is not None: + out.add_(self.bias) + return out diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 4f65446e..4f280161 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -1,3 +1,4 @@ +import os import torch import torch.distributed @@ -16,7 +17,15 @@ except ImportError: from accelerate import init_empty_weights from text_generation_server.utils.gptq.quant_linear import QuantLinear +HAS_EXLLAMA = True +if os.getenv("DISABLE_EXLLAMA") == "True": + HAS_EXLLAMA=False +try: + from text_generation_server.utils.gptq.exllama import Ex4bitLinear +except ImportError: + HAS_EXLLAMA = False +from typing import Optional # Monkey patching @classmethod @@ -144,21 +153,24 @@ def get_linear(weight, bias, quantize): linear.bias = nn.Parameter(bias) elif quantize == "gptq": try: - qweight, qzeros, scales, g_idx, bits, groupsize = weight + qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight except Exception: raise NotImplementedError( f"The passed weight is not `gptq` compatible, loader needs to be updated." ) - linear = QuantLinear( - qweight, - qzeros, - scales, - g_idx, - bias, - bits, - groupsize, - ) + if use_exllama: + linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize) + else: + linear = QuantLinear( + qweight, + qzeros, + scales, + g_idx, + bias, + bits, + groupsize, + ) else: raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") return linear diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 723d0558..4f284ea4 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,7 +1,8 @@ from pathlib import Path -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Tuple from safetensors import safe_open, SafetensorError import torch +from loguru import logger class Weights: @@ -127,18 +128,8 @@ class Weights: torch.testing.assert_close(w2, w[0]) g_idx = w[0] - try: - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() - except (SafetensorError, RuntimeError) as e: - try: - import os - - bits = int(os.getenv("GPTQ_BITS")) - groupsize = int(os.getenv("GPTQ_GROUPSIZE")) - except Exception: - raise e - weight = (qweight, qzeros, scales, g_idx, bits, groupsize) + bits, groupsize = self._get_gptq_qparams() + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] weight = torch.cat(w, dim=dim) @@ -146,29 +137,74 @@ class Weights: def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq": + use_exllama = True + bits, groupsize = self._get_gptq_qparams() + + if bits != 4: + use_exllama = False + + if self.process_group.size() > 1: + g_idx = self.get_tensor(f"{prefix}.g_idx") + if g_idx is not None: + if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all(): + # Exllama implementation does not support row tensor parallelism with act-order, as + # it would require to reorder input activations that are split unto several GPUs + use_exllama = False + try: qweight = self.get_sharded(f"{prefix}.qweight", dim=0) except RuntimeError: - raise RuntimeError( - "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" - ) - qzeros = self.get_tensor(f"{prefix}.qzeros") - scales = self.get_tensor(f"{prefix}.scales") - g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`") + - try: - bits = self.get_tensor("gptq_bits").item() - groupsize = self.get_tensor("gptq_groupsize").item() - except (SafetensorError, RuntimeError) as e: - try: - import os + from text_generation_server.utils.layers import HAS_EXLLAMA + if use_exllama: + if not HAS_EXLLAMA: + logger.warning("Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True") + use_exllama = False + else: + logger.info("Using exllama kernels") - bits = int(os.getenv("GPTQ_BITS")) - groupsize = int(os.getenv("GPTQ_GROUPSIZE")) - except Exception: - raise e - weight = (qweight, qzeros, scales, g_idx, bits, groupsize) + if use_exllama: + if groupsize >= 0: + # Exllama reorders the weights in advance and the activations on the fly, thus + # the scales and zero-points do not need to be reordered. + qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0) + scales = self.get_sharded(f"{prefix}.scales", dim=0) + else: + raise RuntimeError("Using exllama GPTQ kernel with groupsize<1 is not supported") + # qzeros = self.get_tensor(f"{prefix}.qzeros") + # scales = self.get_tensor(f"{prefix}.scales") + + # For tp > 1, at this point we know we do not use act-order + if self.process_group.size() == 1: + g_idx = self.get_tensor(f"{prefix}.g_idx") + else: + g_idx = None + else: + # The triton kernel reorders the scales/zero points instead of the weight/activation. + # Thus, each rank needs the full qzeros/scales. + qzeros = self.get_tensor(f"{prefix}.qzeros") + scales = self.get_tensor(f"{prefix}.scales") + g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0) + + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: weight = self.get_sharded(f"{prefix}.weight", dim=1) return weight + + def _get_gptq_qparams(self) -> Tuple[int, int]: + try: + bits = self.get_tensor("gptq_bits").item() + groupsize = self.get_tensor("gptq_groupsize").item() + except (SafetensorError, RuntimeError) as e: + try: + import os + + bits = int(os.getenv("GPTQ_BITS")) + groupsize = int(os.getenv("GPTQ_GROUPSIZE")) + except Exception: + raise e + + return bits, groupsize