Just trying to get the integration tests to pass. # What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com>
This commit is contained in:
parent
bf94df3c71
commit
d5b5bc750f
13
Dockerfile
13
Dockerfile
|
@ -108,6 +108,17 @@ COPY server/Makefile-flash-att-v2 Makefile
|
|||
# Build specific version of flash attention v2
|
||||
RUN make build-flash-attention-v2
|
||||
|
||||
# Build Transformers exllama kernels
|
||||
FROM kernel-builder as exllama-kernels-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY server/exllama_kernels/ .
|
||||
|
||||
|
||||
# Build specific version of transformers
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
|
||||
|
||||
# Build Transformers CUDA kernels
|
||||
FROM kernel-builder as custom-kernels-builder
|
||||
|
||||
|
@ -161,6 +172,8 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86
|
|||
|
||||
# Copy build artifacts from custom kernels builder
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
|
||||
# Copy build artifacts from exllama kernels builder
|
||||
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
|
||||
|
||||
# Copy builds artifacts from vllm builder
|
||||
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
|
||||
|
|
3
Makefile
3
Makefile
|
@ -56,3 +56,6 @@ run-bloom:
|
|||
|
||||
run-bloom-quantize:
|
||||
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
|
||||
|
||||
clean:
|
||||
rm -rf target aml
|
||||
|
|
|
@ -230,15 +230,16 @@ def launcher(event_loop):
|
|||
shard_uds_path,
|
||||
]
|
||||
|
||||
env = os.environ
|
||||
|
||||
if num_shard is not None:
|
||||
args.extend(["--num-shard", str(num_shard)])
|
||||
if quantize:
|
||||
if quantize is not None:
|
||||
args.append("--quantize")
|
||||
args.append("bitsandbytes")
|
||||
args.append(quantize)
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
|
||||
env = os.environ
|
||||
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
||||
|
||||
if not use_flash_attention:
|
||||
|
@ -275,9 +276,9 @@ def launcher(event_loop):
|
|||
|
||||
if num_shard is not None:
|
||||
args.extend(["--num-shard", str(num_shard)])
|
||||
if quantize:
|
||||
if quantize is not None:
|
||||
args.append("--quantize")
|
||||
args.append("bitsandbytes")
|
||||
args.append(quantize)
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.59375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.6640625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29918,
|
||||
"logprob": -2.3867188,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5338,
|
||||
"logprob": -2.8183594,
|
||||
"special": false,
|
||||
"text": "uri"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.6367188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.0527344,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.6542969,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 29918,
|
||||
"logprob": -0.056121826,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5338,
|
||||
"logprob": -0.01600647,
|
||||
"special": false,
|
||||
"text": "uri"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.87939453,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.7529297,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.2980957,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "_uri\nTest request_uri\nTest request"
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.6015625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.6640625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29899,
|
||||
"logprob": -1.1640625,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 1454,
|
||||
"logprob": -0.07543945,
|
||||
"special": false,
|
||||
"text": "for"
|
||||
},
|
||||
{
|
||||
"id": 29899,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 9342,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "comment"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 396,
|
||||
"logprob": -0.2956543,
|
||||
"special": false,
|
||||
"text": " #"
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.52734375,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 29900,
|
||||
"logprob": -0.6899414,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 29896,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 29946,
|
||||
"logprob": -1.5068359,
|
||||
"special": false,
|
||||
"text": "4"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "Test request-for-comment: #2014"
|
||||
}
|
|
@ -0,0 +1,354 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.6015625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.671875,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29918,
|
||||
"logprob": -2.3828125,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5338,
|
||||
"logprob": -2.8105469,
|
||||
"special": false,
|
||||
"text": "uri"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.6396484,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.0546875,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.6513672,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 29918,
|
||||
"logprob": -0.056365967,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5338,
|
||||
"logprob": -0.016082764,
|
||||
"special": false,
|
||||
"text": "uri"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.87841797,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.7548828,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.29711914,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "_uri\nTest request_uri\nTest request"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.6015625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.6640625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29918,
|
||||
"logprob": -2.3828125,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5338,
|
||||
"logprob": -2.828125,
|
||||
"special": false,
|
||||
"text": "uri"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.6386719,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.0527344,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.6542969,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 29918,
|
||||
"logprob": -0.055877686,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5338,
|
||||
"logprob": -0.016021729,
|
||||
"special": false,
|
||||
"text": "uri"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.8769531,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.7583008,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.29833984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "_uri\nTest request_uri\nTest request"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.6015625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.671875,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29918,
|
||||
"logprob": -2.3847656,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5338,
|
||||
"logprob": -2.8144531,
|
||||
"special": false,
|
||||
"text": "uri"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.6396484,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.0527344,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.65478516,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 29918,
|
||||
"logprob": -0.056243896,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5338,
|
||||
"logprob": -0.016143799,
|
||||
"special": false,
|
||||
"text": "uri"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.8808594,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.75341797,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.2956543,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "_uri\nTest request_uri\nTest request"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.6015625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.6640625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29918,
|
||||
"logprob": -2.3769531,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5338,
|
||||
"logprob": -2.8183594,
|
||||
"special": false,
|
||||
"text": "uri"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.6396484,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.0546875,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.65478516,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 29918,
|
||||
"logprob": -0.05557251,
|
||||
"special": false,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 5338,
|
||||
"logprob": -0.01612854,
|
||||
"special": false,
|
||||
"text": "uri"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.8730469,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.7519531,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.29785156,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "_uri\nTest request_uri\nTest request"
|
||||
}
|
||||
]
|
|
@ -0,0 +1,193 @@
|
|||
{
|
||||
"generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L",
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 20,
|
||||
"seed": null,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
"text": "def",
|
||||
"logprob": null
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"text": " ge",
|
||||
"logprob": -9.0234375
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"text": "ometric",
|
||||
"logprob": -9.0859375
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"text": "_",
|
||||
"logprob": -0.25878906
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"text": "mean",
|
||||
"logprob": -2.2109375
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"text": "(",
|
||||
"logprob": -0.30371094
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"text": "L",
|
||||
"logprob": -5.6054688
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"text": ":",
|
||||
"logprob": -3.0722656
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"text": " List",
|
||||
"logprob": -0.6879883
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"text": "[",
|
||||
"logprob": -0.38500977
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"text": "float",
|
||||
"logprob": -0.984375
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"text": "]):",
|
||||
"logprob": -2.5351562
|
||||
}
|
||||
],
|
||||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"text": "\n ",
|
||||
"logprob": -1.1738281,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 442,
|
||||
"text": " return",
|
||||
"logprob": -0.95947266,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 3632,
|
||||
"text": " sum",
|
||||
"logprob": -1.4199219,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"text": "(",
|
||||
"logprob": -0.085876465,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"text": "L",
|
||||
"logprob": -0.09875488,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"text": ")",
|
||||
"logprob": -0.30517578,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 517,
|
||||
"text": " /",
|
||||
"logprob": -0.42089844,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 2069,
|
||||
"text": " len",
|
||||
"logprob": -0.042053223,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"text": "(",
|
||||
"logprob": -0.0011806488,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"text": "L",
|
||||
"logprob": -0.0005259514,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"text": ")",
|
||||
"logprob": -0.0017633438,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 478,
|
||||
"text": "\n\n",
|
||||
"logprob": -0.69189453,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"text": "\n",
|
||||
"logprob": -0.041870117,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 589,
|
||||
"text": "def",
|
||||
"logprob": -0.27856445,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"text": " ge",
|
||||
"logprob": -1.7255859,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"text": "ometric",
|
||||
"logprob": -0.011291504,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"text": "_",
|
||||
"logprob": -0.008430481,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"text": "mean",
|
||||
"logprob": -0.025787354,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"text": "(",
|
||||
"logprob": -0.073913574,
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"text": "L",
|
||||
"logprob": -0.09967041,
|
||||
"special": false
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
|
@ -0,0 +1,193 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 20,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": null,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -9.0234375,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -9.09375,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -0.25976562,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"logprob": -2.2148438,
|
||||
"text": "mean"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.3010254,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -5.6757812,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -3.0898438,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -0.6791992,
|
||||
"text": " List"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.38891602,
|
||||
"text": "["
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"logprob": -0.92041016,
|
||||
"text": "float"
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"logprob": -2.5390625,
|
||||
"text": "]):"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 442,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " return"
|
||||
},
|
||||
{
|
||||
"id": 11665,
|
||||
"logprob": -1.6005859,
|
||||
"special": false,
|
||||
"text": " reduce"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 5962,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "lambda"
|
||||
},
|
||||
{
|
||||
"id": 816,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " x"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 533,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " y"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 816,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " x"
|
||||
},
|
||||
{
|
||||
"id": 319,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " *"
|
||||
},
|
||||
{
|
||||
"id": 533,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " y"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 498,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " L"
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ")"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": -0.11968994,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "ometric"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric"
|
||||
}
|
|
@ -0,0 +1,534 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": null,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -9.0234375,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -9.0859375,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -0.25927734,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"logprob": -2.25,
|
||||
"text": "mean"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.30126953,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -5.7539062,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -3.0878906,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -0.6845703,
|
||||
"text": " List"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.3918457,
|
||||
"text": "["
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"logprob": -0.8798828,
|
||||
"text": "float"
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"logprob": -2.4980469,
|
||||
"text": "]):"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -1.1533203,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 442,
|
||||
"logprob": -0.91796875,
|
||||
"special": false,
|
||||
"text": " return"
|
||||
},
|
||||
{
|
||||
"id": 3632,
|
||||
"logprob": -1.3291016,
|
||||
"special": false,
|
||||
"text": " sum"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.08062744,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -0.097717285,
|
||||
"special": false,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"logprob": -0.29003906,
|
||||
"special": false,
|
||||
"text": ")"
|
||||
},
|
||||
{
|
||||
"id": 517,
|
||||
"logprob": -0.34958984,
|
||||
"special": false,
|
||||
"text": " /"
|
||||
},
|
||||
{
|
||||
"id": 2069,
|
||||
"logprob": -0.03829956,
|
||||
"special": false,
|
||||
"text": " len"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.0011987686,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -0.00050878525,
|
||||
"special": false,
|
||||
"text": "L"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "\n return sum(L) / len(L"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": null,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -9.0234375,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -9.0859375,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -0.25878906,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"logprob": -2.2109375,
|
||||
"text": "mean"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.30371094,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -5.6054688,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -3.0722656,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -0.6879883,
|
||||
"text": " List"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.38500977,
|
||||
"text": "["
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"logprob": -0.984375,
|
||||
"text": "float"
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"logprob": -2.5351562,
|
||||
"text": "]):"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -1.1738281,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 442,
|
||||
"logprob": -0.9584961,
|
||||
"special": false,
|
||||
"text": " return"
|
||||
},
|
||||
{
|
||||
"id": 3632,
|
||||
"logprob": -1.4169922,
|
||||
"special": false,
|
||||
"text": " sum"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.085876465,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -0.0982666,
|
||||
"special": false,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"logprob": -0.3022461,
|
||||
"special": false,
|
||||
"text": ")"
|
||||
},
|
||||
{
|
||||
"id": 517,
|
||||
"logprob": -0.40504883,
|
||||
"special": false,
|
||||
"text": " /"
|
||||
},
|
||||
{
|
||||
"id": 2069,
|
||||
"logprob": -0.041656494,
|
||||
"special": false,
|
||||
"text": " len"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.0011844635,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -0.0005264282,
|
||||
"special": false,
|
||||
"text": "L"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "\n return sum(L) / len(L"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": null,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -9.0234375,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -9.0859375,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -0.25927734,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"logprob": -2.25,
|
||||
"text": "mean"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.30126953,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -5.7539062,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -3.0878906,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -0.6845703,
|
||||
"text": " List"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.3918457,
|
||||
"text": "["
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"logprob": -0.8798828,
|
||||
"text": "float"
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"logprob": -2.4980469,
|
||||
"text": "]):"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -1.1533203,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 442,
|
||||
"logprob": -0.9165039,
|
||||
"special": false,
|
||||
"text": " return"
|
||||
},
|
||||
{
|
||||
"id": 3632,
|
||||
"logprob": -1.328125,
|
||||
"special": false,
|
||||
"text": " sum"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.07946777,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -0.09820557,
|
||||
"special": false,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"logprob": -0.28930664,
|
||||
"special": false,
|
||||
"text": ")"
|
||||
},
|
||||
{
|
||||
"id": 517,
|
||||
"logprob": -0.34592773,
|
||||
"special": false,
|
||||
"text": " /"
|
||||
},
|
||||
{
|
||||
"id": 2069,
|
||||
"logprob": -0.038330078,
|
||||
"special": false,
|
||||
"text": " len"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.0011940002,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -0.00050878525,
|
||||
"special": false,
|
||||
"text": "L"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "\n return sum(L) / len(L"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 589,
|
||||
"logprob": null,
|
||||
"text": "def"
|
||||
},
|
||||
{
|
||||
"id": 3226,
|
||||
"logprob": -9.0234375,
|
||||
"text": " ge"
|
||||
},
|
||||
{
|
||||
"id": 21017,
|
||||
"logprob": -9.0859375,
|
||||
"text": "ometric"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -0.25927734,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 6009,
|
||||
"logprob": -2.25,
|
||||
"text": "mean"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.30126953,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -5.7539062,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"logprob": -3.0878906,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1682,
|
||||
"logprob": -0.6845703,
|
||||
"text": " List"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"logprob": -0.3918457,
|
||||
"text": "["
|
||||
},
|
||||
{
|
||||
"id": 1808,
|
||||
"logprob": -0.8798828,
|
||||
"text": "float"
|
||||
},
|
||||
{
|
||||
"id": 10794,
|
||||
"logprob": -2.4980469,
|
||||
"text": "]):"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 284,
|
||||
"logprob": -1.1533203,
|
||||
"special": false,
|
||||
"text": "\n "
|
||||
},
|
||||
{
|
||||
"id": 442,
|
||||
"logprob": -0.91259766,
|
||||
"special": false,
|
||||
"text": " return"
|
||||
},
|
||||
{
|
||||
"id": 3632,
|
||||
"logprob": -1.3251953,
|
||||
"special": false,
|
||||
"text": " sum"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.08062744,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -0.09906006,
|
||||
"special": false,
|
||||
"text": "L"
|
||||
},
|
||||
{
|
||||
"id": 27,
|
||||
"logprob": -0.28979492,
|
||||
"special": false,
|
||||
"text": ")"
|
||||
},
|
||||
{
|
||||
"id": 517,
|
||||
"logprob": -0.35958984,
|
||||
"special": false,
|
||||
"text": " /"
|
||||
},
|
||||
{
|
||||
"id": 2069,
|
||||
"logprob": -0.038604736,
|
||||
"special": false,
|
||||
"text": " len"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"logprob": -0.0011901855,
|
||||
"special": false,
|
||||
"text": "("
|
||||
},
|
||||
{
|
||||
"id": 62,
|
||||
"logprob": -0.0005078316,
|
||||
"special": false,
|
||||
"text": "L"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "\n return sum(L) / len(L"
|
||||
}
|
||||
]
|
|
@ -0,0 +1,57 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_gptq_handle(launcher):
|
||||
with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_gptq(flash_llama_gptq_handle):
|
||||
await flash_llama_gptq_handle.health(300)
|
||||
return flash_llama_gptq_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
|
||||
response = await flash_llama_gptq.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
|
||||
response = await flash_llama_gptq.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_gptq_load(flash_llama_gptq, generate_load, response_snapshot):
|
||||
responses = await generate_load(flash_llama_gptq, "Test request", max_new_tokens=10, n=4)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -0,0 +1,49 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_starcoder_gptq_handle(launcher):
|
||||
with launcher("Narsil/starcoder-gptq", num_shard=2, quantize="gptq") as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
|
||||
await flash_starcoder_gptq_handle.health(300)
|
||||
return flash_starcoder_gptq_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
|
||||
response = await flash_starcoder_gptq.generate(
|
||||
"def geometric_mean(L: List[float]):", max_new_tokens=20, decoder_input_details=True,
|
||||
)
|
||||
assert response.details.generated_tokens == 20
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, response_snapshot):
|
||||
response = await flash_starcoder_gptq.generate(
|
||||
"def geometric_mean(L: List[float]):",
|
||||
max_new_tokens=20,
|
||||
temperature=0.2,
|
||||
top_p=0.95,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
assert response.details.generated_tokens == 20
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, response_snapshot):
|
||||
responses = await generate_load(flash_starcoder_gptq, "def geometric_mean(L: List[float]):", max_new_tokens=10, n=4)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -0,0 +1,71 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#define _cuda_buffers_cu
|
||||
#include "cuda_buffers.cuh"
|
||||
|
||||
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
|
||||
// __constant__ half2 q4_table[16][256];
|
||||
// half2 q4_table_host[16][256];
|
||||
// bool q4_table_init = false;
|
||||
|
||||
CudaBuffers::CudaBuffers
|
||||
(
|
||||
int _device,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
) :
|
||||
device(_device),
|
||||
temp_state(_temp_state),
|
||||
temp_dq(_temp_dq)
|
||||
{
|
||||
cudaSetDevice(_device);
|
||||
|
||||
cudaStreamCreate(&alt_stream_1);
|
||||
cudaStreamCreate(&alt_stream_2);
|
||||
cudaStreamCreate(&alt_stream_3);
|
||||
cudaEventCreate(&alt_stream_1_done);
|
||||
cudaEventCreate(&alt_stream_2_done);
|
||||
cudaEventCreate(&alt_stream_3_done);
|
||||
}
|
||||
|
||||
CudaBuffers::~CudaBuffers()
|
||||
{
|
||||
cudaStreamDestroy(alt_stream_1);
|
||||
cudaStreamDestroy(alt_stream_2);
|
||||
cudaStreamDestroy(alt_stream_3);
|
||||
cudaEventDestroy(alt_stream_1_done);
|
||||
cudaEventDestroy(alt_stream_2_done);
|
||||
cudaEventDestroy(alt_stream_3_done);
|
||||
}
|
||||
|
||||
CudaBuffers* get_buffers(const int device_index)
|
||||
{
|
||||
return g_buffers[device_index];
|
||||
}
|
||||
|
||||
void prepare_buffers_cuda
|
||||
(
|
||||
int _device,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
)
|
||||
{
|
||||
CudaBuffers* buffers = new CudaBuffers
|
||||
(
|
||||
_device,
|
||||
_temp_state,
|
||||
_temp_dq
|
||||
);
|
||||
|
||||
g_buffers[_device] = buffers;
|
||||
}
|
||||
|
||||
void cleanup_buffers_cuda()
|
||||
{
|
||||
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
|
||||
{
|
||||
if (!g_buffers[i]) continue;
|
||||
delete g_buffers[i];
|
||||
g_buffers[i] = NULL;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _cuda_buffers_cuh
|
||||
#define _cuda_buffers_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
const int CUDA_MAX_DEVICES = 16;
|
||||
|
||||
// #ifndef _cuda_buffers_cu
|
||||
// extern __constant__ half2 q4_table[16][256];
|
||||
// #endif
|
||||
|
||||
class CudaBuffers
|
||||
{
|
||||
public:
|
||||
int device;
|
||||
|
||||
half* temp_state; // [max_hidden_rows * intermediate_size]
|
||||
half* temp_dq; // size of largest quant tensor * 8
|
||||
|
||||
cudaStream_t alt_stream_1;
|
||||
cudaStream_t alt_stream_2;
|
||||
cudaStream_t alt_stream_3;
|
||||
cudaEvent_t alt_stream_1_done;
|
||||
cudaEvent_t alt_stream_2_done;
|
||||
cudaEvent_t alt_stream_3_done;
|
||||
|
||||
CudaBuffers
|
||||
(
|
||||
int _device,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
);
|
||||
~CudaBuffers();
|
||||
};
|
||||
|
||||
CudaBuffers* get_buffers(const int device_index);
|
||||
|
||||
void prepare_buffers_cuda
|
||||
(
|
||||
int _device,
|
||||
half* _temp_state,
|
||||
half* _temp_dq
|
||||
);
|
||||
|
||||
void cleanup_buffers_cuda();
|
||||
|
||||
#endif
|
|
@ -0,0 +1,58 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _cuda_compat_cuh
|
||||
#define _cuda_compat_cuh
|
||||
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||
{
|
||||
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if __CUDA_ARCH__ < 700
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
|
||||
#if __CUDA_ARCH__ < 600
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif
|
|
@ -0,0 +1,61 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "column_remap.cuh"
|
||||
#include "../util.cuh"
|
||||
|
||||
const int SHUF_BLOCKSIZE_X = 256;
|
||||
const int SHUF_BLOCKSIZE_Y = 16;
|
||||
|
||||
__global__ void column_remap_kernel
|
||||
(
|
||||
const half* __restrict__ x,
|
||||
half* __restrict__ x_new,
|
||||
const int x_width,
|
||||
const int x_height,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
|
||||
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
|
||||
|
||||
int x_stride = x_width;
|
||||
int x_idx = x_row * x_stride + x_column;
|
||||
|
||||
int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
|
||||
int x_idx_end = x_row_end * x_stride + x_column;
|
||||
|
||||
int s_column = x_map[x_column];
|
||||
int s_idx = x_row * x_stride + s_column;
|
||||
|
||||
while (x_idx < x_idx_end)
|
||||
{
|
||||
x_new[x_idx] = x[s_idx];
|
||||
x_idx += x_stride;
|
||||
s_idx += x_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// Remap columns in x to correspond to sequential group index before matmul
|
||||
//
|
||||
// perform x -> seq_x such that seq_x @ seq_w == x @ w
|
||||
|
||||
void column_remap_cuda
|
||||
(
|
||||
const half* x,
|
||||
half* x_new,
|
||||
const int x_height,
|
||||
const int x_width,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
|
||||
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
|
||||
1
|
||||
);
|
||||
|
||||
column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _column_remap_cuh
|
||||
#define _column_remap_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
|
||||
void column_remap_cuda
|
||||
(
|
||||
const half* x,
|
||||
half* x_new,
|
||||
const int x_height,
|
||||
const int x_width,
|
||||
const uint32_t* x_map
|
||||
);
|
||||
|
||||
#endif
|
|
@ -0,0 +1,252 @@
|
|||
#include "q4_matmul.cuh"
|
||||
#include "column_remap.cuh"
|
||||
#include "../util.cuh"
|
||||
#include "../matrix.cuh"
|
||||
#include "../cuda_compat.cuh"
|
||||
#include "../cuda_buffers.cuh"
|
||||
|
||||
const int THREADS_X = 32; // Block size and thread count along columns in w and out
|
||||
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
|
||||
|
||||
typedef void (*fp_q4_matmul_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
half*,
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint32_t*,
|
||||
bool
|
||||
);
|
||||
|
||||
template<bool use_half2, bool use_groupsize, bool use_x_map>
|
||||
__global__ void q4_matmul_kernel
|
||||
(
|
||||
const half* __restrict__ x,
|
||||
const uint32_t* __restrict__ w,
|
||||
half* __restrict__ out,
|
||||
const half* __restrict__ w_scales,
|
||||
const uint32_t* __restrict__ w_zeros,
|
||||
const int height,
|
||||
const int dim,
|
||||
const int width,
|
||||
const int groupsize,
|
||||
const int block_size_z,
|
||||
const uint32_t* __restrict__ x_map,
|
||||
bool no_zero
|
||||
)
|
||||
{
|
||||
// Start of block
|
||||
|
||||
int x_column = block_size_z * blockIdx.z;
|
||||
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
|
||||
|
||||
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
|
||||
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
|
||||
|
||||
int iterations = (x_column_end - x_column) / 8;
|
||||
|
||||
// Views
|
||||
|
||||
MatrixView_half x_(x, height, dim);
|
||||
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
|
||||
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
|
||||
MatrixView_q4_column w_(w, dim, width);
|
||||
MatrixView_half_rw out_(out, height, width);
|
||||
|
||||
// Zero output
|
||||
|
||||
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Loop over part of x row (and w column)
|
||||
|
||||
half2 acc = {};
|
||||
half acc_h = {};
|
||||
|
||||
if constexpr (use_groupsize)
|
||||
{
|
||||
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
|
||||
// could be slightly faster
|
||||
|
||||
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
|
||||
{
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
else
|
||||
{
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
|
||||
|
||||
for (int k = x_column; k < x_column + iterations * 8; k += 8)
|
||||
{
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add to block result
|
||||
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half result = __hadd(acc.x, acc.y);
|
||||
atomicAdd(out_.item_ptr(x_row, w_column), result);
|
||||
}
|
||||
else
|
||||
{
|
||||
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
|
||||
}
|
||||
}
|
||||
|
||||
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
|
||||
{
|
||||
// <bool use_half2, bool use_groupsize, bool use_x_map>
|
||||
if (tuningParams->matmul_no_half2) {
|
||||
if (block_size_z % groupsize == 0) {
|
||||
if (x_map) return q4_matmul_kernel<false, true, true >;
|
||||
else return q4_matmul_kernel<false, true, false>;
|
||||
} else {
|
||||
if (x_map) return q4_matmul_kernel<false, false, true >;
|
||||
else return q4_matmul_kernel<false, false, false>;
|
||||
}
|
||||
} else {
|
||||
if (block_size_z % groupsize == 0)
|
||||
{
|
||||
if (x_map) return q4_matmul_kernel<true, true, true >;
|
||||
else return q4_matmul_kernel<true, true, false>;
|
||||
} else {
|
||||
if (x_map) return q4_matmul_kernel<true, false, true >;
|
||||
else return q4_matmul_kernel<true, false, false>;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Compute y = x @ w
|
||||
|
||||
void q4_matmul_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
const Q4Matrix* w,
|
||||
half* out,
|
||||
bool no_zero,
|
||||
cudaStream_t alt_stream
|
||||
)
|
||||
{
|
||||
int height = x_height;
|
||||
int dim = w->height;
|
||||
int width = w->width;
|
||||
|
||||
cudaSetDevice(w->device);
|
||||
|
||||
uint32_t* x_map = w->cuda_x_map;
|
||||
const half* x_mapped = x;
|
||||
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
|
||||
{
|
||||
CudaBuffers* buffers = get_buffers(w->device);
|
||||
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
|
||||
x_mapped = buffers->temp_state;
|
||||
x_map = NULL;
|
||||
}
|
||||
|
||||
int block_size_z;
|
||||
if (w->width == 4096) block_size_z = 384; // 7B
|
||||
else if (w->width == 11008) block_size_z = 256;
|
||||
else if (w->width == 5120) block_size_z = 384; // 13B
|
||||
else if (w->width == 13824) block_size_z = 256;
|
||||
else if (w->width == 6656) block_size_z = 256; // 33B
|
||||
else if (w->width == 17920) block_size_z = 128;
|
||||
else block_size_z = 256;
|
||||
|
||||
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
|
||||
|
||||
dim3 threads(THREADS_X, THREADS_Y, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(width + threads.x - 1) / threads.x,
|
||||
(height + threads.y - 1) / threads.y,
|
||||
(dim + block_size_z - 1) / block_size_z
|
||||
);
|
||||
|
||||
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
|
||||
|
||||
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
|
||||
}
|
||||
|
||||
void q4_matmul_recons_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
Q4Matrix* w,
|
||||
half* out,
|
||||
const cublasHandle_t handle,
|
||||
bool no_zero
|
||||
)
|
||||
{
|
||||
int height = x_height;
|
||||
int dim = w->height;
|
||||
int width = w->width;
|
||||
|
||||
cudaSetDevice(w->device);
|
||||
CudaBuffers* buffers = get_buffers(w->device);
|
||||
|
||||
const half* x_mapped = x;
|
||||
if (w->cuda_x_map)
|
||||
{
|
||||
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
|
||||
x_mapped = buffers->temp_state;
|
||||
}
|
||||
|
||||
w->reconstruct(buffers->temp_dq);
|
||||
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
|
||||
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
|
||||
|
||||
// const float alpha = 1.0f;
|
||||
// const float beta = no_zero ? 1.0f : 0.0f;
|
||||
// cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
|
||||
// x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _q4_matmul_cuh
|
||||
#define _q4_matmul_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "q4_matrix.cuh"
|
||||
#include "../tuning.h"
|
||||
|
||||
void q4_matmul_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
const Q4Matrix* w,
|
||||
half* out,
|
||||
bool no_zero = false,
|
||||
cudaStream_t alt_stream = NULL
|
||||
);
|
||||
|
||||
void q4_matmul_recons_cuda
|
||||
(
|
||||
ExLlamaTuning* tuningParams,
|
||||
const half* x,
|
||||
const int x_height,
|
||||
Q4Matrix* w,
|
||||
half* out,
|
||||
const cublasHandle_t handle,
|
||||
bool no_zero = false
|
||||
);
|
||||
|
||||
#endif
|
|
@ -0,0 +1,217 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "q4_matrix.cuh"
|
||||
#include <vector>
|
||||
#include "../util.cuh"
|
||||
#include "../matrix.cuh"
|
||||
|
||||
using namespace std;
|
||||
|
||||
const int UNSHUF_BLOCKSIZE_X = 64;
|
||||
|
||||
const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
|
||||
const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
|
||||
|
||||
vector<Q4Matrix*> g_q4_matrices;
|
||||
|
||||
void g_q4_keep_matrix(Q4Matrix* m)
|
||||
{
|
||||
g_q4_matrices.push_back(m);
|
||||
}
|
||||
|
||||
void g_q4_free_matrices()
|
||||
{
|
||||
for (const auto& m : g_q4_matrices) delete m;
|
||||
g_q4_matrices.clear();
|
||||
}
|
||||
|
||||
Q4Matrix::Q4Matrix
|
||||
(
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _qweight,
|
||||
uint32_t* _qzeros,
|
||||
half* _scales,
|
||||
uint32_t* _g_idx,
|
||||
|
||||
const int _device
|
||||
) :
|
||||
height(_height),
|
||||
width(_width),
|
||||
groups(_groups),
|
||||
device(_device)
|
||||
{
|
||||
cudaSetDevice(device);
|
||||
|
||||
cuda_qweight = _qweight;
|
||||
cuda_qzeros = _qzeros;
|
||||
cuda_scales = _scales;
|
||||
|
||||
groupsize = height / groups;
|
||||
|
||||
if (_g_idx) make_sequential(_g_idx);
|
||||
}
|
||||
|
||||
Q4Matrix::~Q4Matrix()
|
||||
{
|
||||
}
|
||||
|
||||
// Make sequential
|
||||
|
||||
__global__ void make_sequential_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
uint32_t* __restrict__ w_new,
|
||||
const uint32_t* __restrict__ x_map,
|
||||
const int w_height,
|
||||
const int w_width
|
||||
)
|
||||
{
|
||||
const uint64_t* w2 = (uint64_t*) w;
|
||||
uint64_t* w_new2 = (uint64_t*) w_new;
|
||||
int w2_stride = w_width >> 1;
|
||||
|
||||
int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
|
||||
int w_new2_row = blockIdx.y;
|
||||
|
||||
int x_map_idx = w_new2_row << 3;
|
||||
|
||||
uint64_t dst = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
int source_row = x_map[x_map_idx++];
|
||||
|
||||
int w2_row = source_row >> 3;
|
||||
int w2_subrow = source_row & 0x07;
|
||||
int w2_row_shift = w2_subrow << 2;
|
||||
int wnew2_row_shift = i << 2;
|
||||
|
||||
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
||||
src >>= w2_row_shift;
|
||||
src &= 0x0000000f0000000f;
|
||||
src <<= wnew2_row_shift;
|
||||
dst |= src;
|
||||
}
|
||||
|
||||
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
||||
}
|
||||
|
||||
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
{
|
||||
uint32_t* cuda_new_qweight = NULL;
|
||||
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
|
||||
|
||||
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
|
||||
// Group histogram
|
||||
|
||||
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
|
||||
|
||||
// Group map
|
||||
|
||||
for (int i = 0, acc = 0; i < groups; i++)
|
||||
{
|
||||
short tmp = cpu_g_idx_map[i];
|
||||
cpu_g_idx_map[i] = acc;
|
||||
acc += tmp;
|
||||
}
|
||||
|
||||
// X map (inverse)
|
||||
|
||||
for (int row = 0; row < height; row++)
|
||||
{
|
||||
uint32_t target_group = cpu_g_idx[row];
|
||||
uint32_t target_row = cpu_g_idx_map[target_group];
|
||||
cpu_g_idx_map[target_group]++;
|
||||
cpu_x_map_inv[row] = target_row;
|
||||
}
|
||||
|
||||
// X map
|
||||
|
||||
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
|
||||
|
||||
// Move to CUDA
|
||||
|
||||
cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// Rearrange rows in w
|
||||
|
||||
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
|
||||
dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);
|
||||
|
||||
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
|
||||
|
||||
// Replace qweights
|
||||
|
||||
cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||
|
||||
// Cleanup
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
cudaFree(cuda_new_qweight);
|
||||
free(cpu_g_idx_map);
|
||||
free(cpu_x_map);
|
||||
free(cpu_x_map_inv);
|
||||
}
|
||||
|
||||
__global__ void reconstruct_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
half* __restrict__ out, // (y)
|
||||
const half* __restrict__ w_scales,
|
||||
const uint32_t* __restrict__ w_zeros,
|
||||
const int height,
|
||||
const int width,
|
||||
const int groupsize
|
||||
)
|
||||
{
|
||||
// Start of block
|
||||
|
||||
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
|
||||
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
|
||||
|
||||
// Views
|
||||
|
||||
MatrixView_q4_column w_(w, height, width);
|
||||
MatrixView_half_rw out_(out, height, width);
|
||||
MatrixView_half w_scales_(w_scales, height / groupsize, width);
|
||||
MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
|
||||
|
||||
// Groupsize version
|
||||
|
||||
int group = row / groupsize;
|
||||
|
||||
half w_scale = w_scales_.item(group, column);
|
||||
uint32_t w_zero = w_zeros_.item(group, column) + 1;
|
||||
|
||||
uint32_t w_read = w_.item_uint32_t(row, column);
|
||||
half* out_ptr = out_.item_ptr(row, column);
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < 32; s += 4)
|
||||
{
|
||||
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
|
||||
*out_ptr = w_item; out_ptr += out_.width;
|
||||
}
|
||||
}
|
||||
|
||||
void Q4Matrix::reconstruct(half* out)
|
||||
{
|
||||
dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
|
||||
|
||||
dim3 blocks
|
||||
(
|
||||
(width + threads.x - 1) / threads.x,
|
||||
(height / 8 + threads.y - 1) / threads.y,
|
||||
1
|
||||
);
|
||||
|
||||
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _q4_matrix_cuh
|
||||
#define _q4_matrix_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
|
||||
class Q4Matrix
|
||||
{
|
||||
public:
|
||||
|
||||
int device;
|
||||
|
||||
int height;
|
||||
int width;
|
||||
int groups;
|
||||
int groupsize;
|
||||
|
||||
uint32_t* cuda_qweight = NULL;
|
||||
uint32_t* cuda_qzeros = NULL;
|
||||
half* cuda_scales = NULL;
|
||||
uint32_t* cuda_x_map = NULL;
|
||||
|
||||
Q4Matrix
|
||||
(
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _qweight,
|
||||
uint32_t* _qzeros,
|
||||
half* _scales,
|
||||
uint32_t* _g_idx,
|
||||
|
||||
const int _device
|
||||
);
|
||||
|
||||
~Q4Matrix();
|
||||
|
||||
void reconstruct(half* out);
|
||||
|
||||
private:
|
||||
|
||||
void make_sequential(const uint32_t* cpu_g_idx);
|
||||
|
||||
};
|
||||
|
||||
void g_q4_keep_matrix(Q4Matrix* m);
|
||||
void g_q4_free_matrices();
|
||||
|
||||
#endif
|
|
@ -0,0 +1,249 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include "util.cuh"
|
||||
#include "tuning.h"
|
||||
#include "cuda_buffers.cuh"
|
||||
#include "cuda_func/q4_matrix.cuh"
|
||||
#include "cuda_func/q4_matmul.cuh"
|
||||
#include "cuda_func/column_remap.cuh"
|
||||
|
||||
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
|
||||
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
|
||||
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
|
||||
|
||||
void check_cuda(cudaError_t ret)
|
||||
{
|
||||
switch (ret)
|
||||
{
|
||||
case cudaSuccess:
|
||||
break;
|
||||
|
||||
case cudaUnspecified:
|
||||
printf(" **** Unspecified error\n");
|
||||
TORCH_CHECK(false, "CUDA error");
|
||||
break;
|
||||
|
||||
default:
|
||||
printf(" **** CUDA error\n"); \
|
||||
printf(" **** %s\n", cudaGetErrorString(ret)); \
|
||||
TORCH_CHECK(false, "CUDA error"); \
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Some decluttering macros
|
||||
|
||||
#define STRINGIFY_(__x) #__x
|
||||
#define STRINGIFY(__x) STRINGIFY_(__x)
|
||||
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
|
||||
|
||||
#define TORCH_CHECK_DEVICE_INDEX(__index) \
|
||||
do { \
|
||||
TORCH_CHECK(__index >= 0, "no device index"); \
|
||||
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
|
||||
} while(0)
|
||||
|
||||
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
|
||||
do { \
|
||||
TORCH_CHECK_DTYPE(__w, kInt); \
|
||||
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
|
||||
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
|
||||
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
|
||||
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
|
||||
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
|
||||
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
|
||||
} while(0)
|
||||
|
||||
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
|
||||
{
|
||||
int groupsize = w.size(0) * 8 / w_zeros.size(0);
|
||||
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
|
||||
return groupsize;
|
||||
}
|
||||
|
||||
|
||||
// Tuning parameters
|
||||
|
||||
ExLlamaTuning tuningParams;
|
||||
|
||||
void set_tuning_params
|
||||
(
|
||||
int matmul_recons_thd,
|
||||
bool matmul_fused_remap,
|
||||
bool matmul_no_half2
|
||||
)
|
||||
{
|
||||
tuningParams.matmul_recons_thd = matmul_recons_thd;
|
||||
tuningParams.matmul_fused_remap = matmul_fused_remap;
|
||||
tuningParams.matmul_no_half2 = matmul_no_half2;
|
||||
}
|
||||
|
||||
|
||||
// Release all unmanaged objects allocated by the extension
|
||||
|
||||
void cleanup()
|
||||
{
|
||||
cleanup_buffers_cuda();
|
||||
g_q4_free_matrices();
|
||||
}
|
||||
|
||||
|
||||
// Prepare buffers for forward pass
|
||||
|
||||
void prepare_buffers
|
||||
(
|
||||
torch::Device device,
|
||||
torch::Tensor temp_state,
|
||||
torch::Tensor temp_dq
|
||||
)
|
||||
{
|
||||
int device_index = device.index();
|
||||
TORCH_CHECK_DEVICE_INDEX(device_index);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device);
|
||||
|
||||
prepare_buffers_cuda
|
||||
(
|
||||
device_index,
|
||||
(half*) temp_state.data_ptr(),
|
||||
(half*) temp_dq.data_ptr()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// Create Q4Matrix, return handle
|
||||
|
||||
uintptr_t make_q4
|
||||
(
|
||||
torch::Tensor qweight,
|
||||
torch::Tensor qzeros,
|
||||
torch::Tensor scales,
|
||||
torch::Tensor g_idx,
|
||||
int device
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(qweight, kInt);
|
||||
TORCH_CHECK_DTYPE(qzeros, kInt);
|
||||
TORCH_CHECK_DTYPE(scales, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
|
||||
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
|
||||
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
|
||||
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
|
||||
|
||||
int width = qweight.size(1);
|
||||
int height = qweight.size(0) * 8;
|
||||
int groups = qzeros.size(0);
|
||||
|
||||
Q4Matrix* m = new Q4Matrix
|
||||
(
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
|
||||
(uint32_t*) qweight.data_ptr(),
|
||||
(uint32_t*) qzeros.data_ptr(),
|
||||
(half*) scales.data_ptr(),
|
||||
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
|
||||
|
||||
device
|
||||
);
|
||||
|
||||
g_q4_keep_matrix(m);
|
||||
return reinterpret_cast<uintptr_t> (m);
|
||||
}
|
||||
|
||||
|
||||
// Matmul half @ quant -> half
|
||||
|
||||
void q4_matmul
|
||||
(
|
||||
torch::Tensor x,
|
||||
uintptr_t w,
|
||||
torch::Tensor out
|
||||
)
|
||||
{
|
||||
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
|
||||
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(out, kHalf);
|
||||
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
|
||||
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
int x_height = x.size(0);
|
||||
|
||||
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
|
||||
{
|
||||
q4_matmul_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr()
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
q4_matmul_recons_cuda
|
||||
(
|
||||
&tuningParams,
|
||||
(half*) x.data_ptr(),
|
||||
x_height,
|
||||
wm,
|
||||
(half*) out.data_ptr(),
|
||||
at::cuda::getCurrentCUDABlasHandle()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Remap columns in half tensor
|
||||
|
||||
void column_remap
|
||||
(
|
||||
torch::Tensor x,
|
||||
torch::Tensor x_new,
|
||||
torch::Tensor x_map
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(x, kHalf);
|
||||
TORCH_CHECK_DTYPE(x_new, kHalf);
|
||||
TORCH_CHECK_DTYPE(x_map, kInt);
|
||||
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
|
||||
|
||||
int height = x.size(0);
|
||||
int width = x.size(1);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
column_remap_cuda
|
||||
(
|
||||
(half*) x.data_ptr(),
|
||||
(half*) x_new.data_ptr(),
|
||||
height,
|
||||
width,
|
||||
(uint32_t*) x_map.data_ptr()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
|
||||
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
|
||||
m.def("cleanup", &cleanup, "cleanup");
|
||||
m.def("make_q4", &make_q4, "make_q4");
|
||||
m.def("q4_matmul", &q4_matmul, "q4_matmul");
|
||||
}
|
|
@ -0,0 +1,294 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _matrix_cuh
|
||||
#define _matrix_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
class MatrixView_half
|
||||
{
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_column
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (row & 0x07) * 4;
|
||||
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
|
||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
|
||||
};
|
||||
|
||||
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
|
||||
|
||||
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
|
||||
|
||||
__device__ __forceinline__ half2 dot_product_8
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half2 result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half2 v_01 = __halves2half2(v_0, v_1);
|
||||
half2 v_23 = __halves2half2(v_2, v_3);
|
||||
half2 v_45 = __halves2half2(v_4, v_5);
|
||||
half2 v_67 = __halves2half2(v_6, v_7);
|
||||
|
||||
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
|
||||
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
|
||||
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
|
||||
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
|
||||
|
||||
half2 tmp = __hmul2(*h_ptr++, v_01);
|
||||
tmp = __hfma2(*h_ptr++, v_23, tmp);
|
||||
tmp = __hfma2(*h_ptr++, v_45, tmp);
|
||||
tmp = __hfma2(*h_ptr++, v_67, tmp);
|
||||
result = __hfma2(v_scale_2, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ half dot_product_8_h
|
||||
(
|
||||
const half acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half v_scale,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half tmp = __hmul(*h_ptr++, v_0);
|
||||
tmp = __hfma(*h_ptr++, v_1, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_2, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_3, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_4, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_5, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_6, tmp);
|
||||
tmp = __hfma(*h_ptr++, v_7, tmp);
|
||||
result = __hfma(v_scale, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
|
||||
|
||||
__device__ __forceinline__ half2 dot_product_8_x_map
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, 0);
|
||||
const uint32_t* x_map_ptr = x_map + h_column;
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half2 result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half2 v_01 = __halves2half2(v_0, v_1);
|
||||
half2 v_23 = __halves2half2(v_2, v_3);
|
||||
half2 v_45 = __halves2half2(v_4, v_5);
|
||||
half2 v_67 = __halves2half2(v_6, v_7);
|
||||
|
||||
half h_0 = h_ptr[*x_map_ptr++];
|
||||
half h_1 = h_ptr[*x_map_ptr++];
|
||||
half h_2 = h_ptr[*x_map_ptr++];
|
||||
half h_3 = h_ptr[*x_map_ptr++];
|
||||
half h_4 = h_ptr[*x_map_ptr++];
|
||||
half h_5 = h_ptr[*x_map_ptr++];
|
||||
half h_6 = h_ptr[*x_map_ptr++];
|
||||
half h_7 = h_ptr[*x_map_ptr++];
|
||||
|
||||
half2 h_01 = __halves2half2(h_0, h_1);
|
||||
half2 h_23 = __halves2half2(h_2, h_3);
|
||||
half2 h_45 = __halves2half2(h_4, h_5);
|
||||
half2 h_67 = __halves2half2(h_6, h_7);
|
||||
|
||||
half2 tmp = __hmul2(h_01, v_01);
|
||||
tmp = __hfma2(h_23, v_23, tmp);
|
||||
tmp = __hfma2(h_45, v_45, tmp);
|
||||
tmp = __hfma2(h_67, v_67, tmp);
|
||||
result = __hfma2(v_scale_2, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ half dot_product_8_x_map_h
|
||||
(
|
||||
const half acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half v_scale,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, 0);
|
||||
const uint32_t* x_map_ptr = x_map + h_column;
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
|
||||
result = __hfma(v_scale, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,13 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _tuning_h
|
||||
#define _tuning_h
|
||||
|
||||
struct ExLlamaTuning
|
||||
{
|
||||
int matmul_recons_thd;
|
||||
bool matmul_fused_remap;
|
||||
bool matmul_no_half2;
|
||||
};
|
||||
|
||||
#endif
|
|
@ -0,0 +1,29 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#ifndef _util_cuh
|
||||
#define _util_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#define cudaUnspecified cudaErrorApiFailureBase
|
||||
|
||||
// React to failure on return code != cudaSuccess
|
||||
|
||||
#define _cuda_check(fn) \
|
||||
do { \
|
||||
{_cuda_err = fn;} \
|
||||
if (_cuda_err != cudaSuccess) goto _cuda_fail; \
|
||||
} while(false)
|
||||
|
||||
// React to failure on return code == 0
|
||||
|
||||
#define _alloc_check(fn) \
|
||||
do { \
|
||||
if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
|
||||
else _cuda_err = cudaSuccess; \
|
||||
} while(false)
|
||||
|
||||
#endif
|
|
@ -0,0 +1,19 @@
|
|||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name="exllama_kernels",
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name="exllama_kernels",
|
||||
sources=[
|
||||
"exllama_kernels/exllama_ext.cpp",
|
||||
"exllama_kernels/cuda_buffers.cu",
|
||||
"exllama_kernels/cuda_func/column_remap.cu",
|
||||
"exllama_kernels/cuda_func/q4_matmul.cu",
|
||||
"exllama_kernels/cuda_func/q4_matrix.cu"
|
||||
],
|
||||
)
|
||||
],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
|
@ -383,7 +383,6 @@ class FlashLlamaLayer(nn.Module):
|
|||
class FlashLlamaModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
|
|
|
@ -20,7 +20,6 @@ from text_generation_server.utils.layers import (
|
|||
)
|
||||
from safetensors import SafetensorError
|
||||
|
||||
|
||||
def load_multi_mqa(
|
||||
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
||||
):
|
||||
|
@ -50,6 +49,7 @@ def _load_multi_mqa_gptq(
|
|||
q_tensor = slice_[:, start:stop]
|
||||
kv_tensor = slice_[:, -2 * head_size :]
|
||||
qweight = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
qweight = qweight.to(device=weights.device)
|
||||
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
|
||||
shape = slice_.get_shape()
|
||||
|
@ -60,6 +60,7 @@ def _load_multi_mqa_gptq(
|
|||
q_tensor = slice_[:, start:stop]
|
||||
kv_tensor = slice_[:, -2 * head_size :]
|
||||
scales = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
scales = scales.to(device=weights.device)
|
||||
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
|
||||
shape = slice_.get_shape()
|
||||
|
@ -70,21 +71,15 @@ def _load_multi_mqa_gptq(
|
|||
q_tensor = slice_[:, start:stop]
|
||||
kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
|
||||
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
qzeros = qzeros.to(device=weights.device)
|
||||
|
||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
||||
try:
|
||||
bits = weights.get_tensor("gptq_bits").item()
|
||||
groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||
except SafetensorError as e:
|
||||
try:
|
||||
import os
|
||||
g_idx = g_idx.to(device=weights.device)
|
||||
bits, groupsize = weights._get_gptq_qparams()
|
||||
|
||||
bits = int(os.getenv("GPTQ_BITS"))
|
||||
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||
from text_generation_server.utils.layers import HAS_EXLLAMA
|
||||
use_exllama = HAS_EXLLAMA
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
|
||||
if bias:
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
|
||||
|
@ -97,6 +92,7 @@ def _load_multi_mqa_gptq(
|
|||
q_tensor = slice_[start:stop]
|
||||
kv_tensor = slice_[-2 * head_size :]
|
||||
bias = torch.cat([q_tensor, kv_tensor], dim=0)
|
||||
bias = bias.to(device=weights.device)
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
else:
|
||||
|
@ -361,7 +357,6 @@ class Block(nn.Module):
|
|||
max_s,
|
||||
):
|
||||
hidden_states, residual = self.ln_1(hidden_states, residual)
|
||||
|
||||
hidden_states = self.attn(
|
||||
hidden_states,
|
||||
cu_seqlen_prefill,
|
||||
|
|
|
@ -6,7 +6,6 @@ import torch.distributed
|
|||
import numpy as np
|
||||
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
from opentelemetry import trace
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from typing import Optional, Tuple, List, Type, Union, Dict
|
||||
|
|
|
@ -3,14 +3,13 @@ import torch
|
|||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Optional, TypeVar, Type
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers import PreTrainedTokenizerBase, PretrainedConfig
|
||||
|
||||
from text_generation_server.models.types import Batch, GeneratedText
|
||||
from text_generation_server.pb.generate_pb2 import InfoResponse
|
||||
|
||||
B = TypeVar("B", bound=Batch)
|
||||
|
||||
|
||||
class Model(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
|
|||
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
|
||||
|
||||
|
||||
|
||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
|
||||
self.cache = cache
|
||||
|
@ -140,6 +141,16 @@ def serve(
|
|||
logger.exception("Error when initializing model")
|
||||
raise
|
||||
|
||||
if quantize == "gptq":
|
||||
try:
|
||||
# When using GPTQ, Exllama kernels need some global kernels
|
||||
# For which we have the finale shapes only after the model has loaded
|
||||
# This will allocate those buffers.
|
||||
from text_generation_server.utils.gptq.exllama import create_exllama_buffers
|
||||
create_exllama_buffers()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
server = aio.server(
|
||||
interceptors=[
|
||||
ExceptionInterceptor(),
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
|
||||
import torch
|
||||
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
none_tensor = torch.empty((1, 1), device = "meta")
|
||||
|
||||
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
|
||||
"""Construct Q4Matrix, return handle"""
|
||||
return make_q4(qweight,
|
||||
qzeros,
|
||||
scales,
|
||||
g_idx if g_idx is not None else none_tensor,
|
||||
device)
|
||||
|
||||
def ext_q4_matmul(x, q4, q4_width):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
outshape = x.shape[:-1] + (q4_width,)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
|
||||
|
||||
q4_matmul(x, q4, output)
|
||||
|
||||
return output.view(outshape)
|
||||
|
||||
MAX_DQ = 1
|
||||
MAX_INNER = 1
|
||||
ACT_ORDER = False
|
||||
DEVICE = None
|
||||
|
||||
TEMP_STATE = None
|
||||
TEMP_DQ = None
|
||||
|
||||
def create_exllama_buffers():
|
||||
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
|
||||
|
||||
if ACT_ORDER:
|
||||
# TODO: this should be set to rust side `max_total_tokens`, but TGI
|
||||
# does not offer an API to expose this variable to python, as this variable
|
||||
# is handled by the client but it appears the model is initialized by the server.
|
||||
# An alternative could be to initialize the buffers during warmup.
|
||||
# Dummy
|
||||
max_total_tokens = 2048
|
||||
else:
|
||||
max_total_tokens = 1
|
||||
|
||||
# This temp_state buffer is required to reorder X in the act-order case.
|
||||
temp_state = torch.zeros((max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE)
|
||||
temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
|
||||
|
||||
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
|
||||
prepare_buffers(DEVICE, temp_state, temp_dq)
|
||||
|
||||
matmul_recons_thd = 8
|
||||
matmul_fused_remap = False
|
||||
matmul_no_half2 = False
|
||||
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
|
||||
|
||||
TEMP_STATE, TEMP_DQ = temp_state, temp_dq
|
||||
|
||||
class Ex4bitLinear:
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
|
||||
assert bits == 4
|
||||
|
||||
self.device = qweight.device
|
||||
self.qweight = qweight
|
||||
self.qzeros = qzeros
|
||||
self.scales = scales
|
||||
self.g_idx = g_idx.cpu() if g_idx is not None else None
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
|
||||
self.empty_g_idx = True
|
||||
self.g_idx = None
|
||||
|
||||
assert self.device.type == "cuda"
|
||||
assert self.device.index is not None
|
||||
|
||||
self.q4 = ext_make_q4(
|
||||
self.qweight,
|
||||
self.qzeros,
|
||||
self.scales,
|
||||
self.g_idx,
|
||||
self.device.index
|
||||
)
|
||||
|
||||
self.height = qweight.shape[0] * 8
|
||||
self.width = qweight.shape[1]
|
||||
|
||||
# Infer groupsize from height of qzeros
|
||||
self.groupsize = None
|
||||
if self.qzeros.shape[0] > 1:
|
||||
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
|
||||
|
||||
if self.groupsize is not None:
|
||||
assert groupsize == self.groupsize
|
||||
|
||||
# Handle act-order matrix
|
||||
if self.g_idx is not None:
|
||||
if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
|
||||
self.act_order = True
|
||||
else:
|
||||
self.act_order = False
|
||||
|
||||
DEVICE = self.qweight.device
|
||||
|
||||
MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)
|
||||
|
||||
if self.act_order:
|
||||
MAX_INNER = max(MAX_INNER, self.height, self.width)
|
||||
|
||||
ACT_ORDER = True
|
||||
|
||||
def forward(self, x):
|
||||
out = ext_q4_matmul(x, self.q4, self.width)
|
||||
|
||||
if self.bias is not None:
|
||||
out.add_(self.bias)
|
||||
return out
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
|
@ -16,7 +17,15 @@ except ImportError:
|
|||
from accelerate import init_empty_weights
|
||||
|
||||
from text_generation_server.utils.gptq.quant_linear import QuantLinear
|
||||
HAS_EXLLAMA = True
|
||||
if os.getenv("DISABLE_EXLLAMA") == "True":
|
||||
HAS_EXLLAMA=False
|
||||
try:
|
||||
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
|
||||
except ImportError:
|
||||
HAS_EXLLAMA = False
|
||||
|
||||
from typing import Optional
|
||||
|
||||
# Monkey patching
|
||||
@classmethod
|
||||
|
@ -144,21 +153,24 @@ def get_linear(weight, bias, quantize):
|
|||
linear.bias = nn.Parameter(bias)
|
||||
elif quantize == "gptq":
|
||||
try:
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize = weight
|
||||
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
|
||||
except Exception:
|
||||
raise NotImplementedError(
|
||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||
)
|
||||
|
||||
linear = QuantLinear(
|
||||
qweight,
|
||||
qzeros,
|
||||
scales,
|
||||
g_idx,
|
||||
bias,
|
||||
bits,
|
||||
groupsize,
|
||||
)
|
||||
if use_exllama:
|
||||
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
else:
|
||||
linear = QuantLinear(
|
||||
qweight,
|
||||
qzeros,
|
||||
scales,
|
||||
g_idx,
|
||||
bias,
|
||||
bits,
|
||||
groupsize,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
|
||||
return linear
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from safetensors import safe_open, SafetensorError
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class Weights:
|
||||
|
@ -127,18 +128,8 @@ class Weights:
|
|||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
except (SafetensorError, RuntimeError) as e:
|
||||
try:
|
||||
import os
|
||||
|
||||
bits = int(os.getenv("GPTQ_BITS"))
|
||||
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
|
||||
except Exception:
|
||||
raise e
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||
bits, groupsize = self._get_gptq_qparams()
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||
else:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
weight = torch.cat(w, dim=dim)
|
||||
|
@ -146,29 +137,74 @@ class Weights:
|
|||
|
||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||
if quantize == "gptq":
|
||||
use_exllama = True
|
||||
bits, groupsize = self._get_gptq_qparams()
|
||||
|
||||
if bits != 4:
|
||||
use_exllama = False
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
if g_idx is not None:
|
||||
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
|
||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||
# it would require to reorder input activations that are split unto several GPUs
|
||||
use_exllama = False
|
||||
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
|
||||
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
except (SafetensorError, RuntimeError) as e:
|
||||
try:
|
||||
import os
|
||||
|
||||
bits = int(os.getenv("GPTQ_BITS"))
|
||||
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
|
||||
except Exception:
|
||||
raise e
|
||||
from text_generation_server.utils.layers import HAS_EXLLAMA
|
||||
if use_exllama:
|
||||
if not HAS_EXLLAMA:
|
||||
logger.warning("Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True")
|
||||
use_exllama = False
|
||||
else:
|
||||
logger.info("Using exllama kernels")
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
|
||||
|
||||
if use_exllama:
|
||||
if groupsize >= 0:
|
||||
# Exllama reorders the weights in advance and the activations on the fly, thus
|
||||
# the scales and zero-points do not need to be reordered.
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
else:
|
||||
raise RuntimeError("Using exllama GPTQ kernel with groupsize<1 is not supported")
|
||||
# qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
# scales = self.get_tensor(f"{prefix}.scales")
|
||||
|
||||
# For tp > 1, at this point we know we do not use act-order
|
||||
if self.process_group.size() == 1:
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
else:
|
||||
g_idx = None
|
||||
else:
|
||||
# The triton kernel reorders the scales/zero points instead of the weight/activation.
|
||||
# Thus, each rank needs the full qzeros/scales.
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return weight
|
||||
|
||||
def _get_gptq_qparams(self) -> Tuple[int, int]:
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
except (SafetensorError, RuntimeError) as e:
|
||||
try:
|
||||
import os
|
||||
|
||||
bits = int(os.getenv("GPTQ_BITS"))
|
||||
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
return bits, groupsize
|
||||
|
|
Loading…
Reference in New Issue