feat(server): Add exllama GPTQ CUDA kernel support #553 (#666)

Just trying to get the integration tests to pass.


# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------

Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com>
This commit is contained in:
Nicolas Patry 2023-07-21 10:59:00 +02:00 committed by GitHub
parent bf94df3c71
commit d5b5bc750f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 3232 additions and 63 deletions

View File

@ -108,6 +108,17 @@ COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2 # Build specific version of flash attention v2
RUN make build-flash-attention-v2 RUN make build-flash-attention-v2
# Build Transformers exllama kernels
FROM kernel-builder as exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Transformers CUDA kernels # Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder FROM kernel-builder as custom-kernels-builder
@ -161,6 +172,8 @@ COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86
# Copy build artifacts from custom kernels builder # Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy builds artifacts from vllm builder # Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

View File

@ -56,3 +56,6 @@ run-bloom:
run-bloom-quantize: run-bloom-quantize:
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080 text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
clean:
rm -rf target aml

View File

@ -230,15 +230,16 @@ def launcher(event_loop):
shard_uds_path, shard_uds_path,
] ]
env = os.environ
if num_shard is not None: if num_shard is not None:
args.extend(["--num-shard", str(num_shard)]) args.extend(["--num-shard", str(num_shard)])
if quantize: if quantize is not None:
args.append("--quantize") args.append("--quantize")
args.append("bitsandbytes") args.append(quantize)
if trust_remote_code: if trust_remote_code:
args.append("--trust-remote-code") args.append("--trust-remote-code")
env = os.environ
env["LOG_LEVEL"] = "info,text_generation_router=debug" env["LOG_LEVEL"] = "info,text_generation_router=debug"
if not use_flash_attention: if not use_flash_attention:
@ -275,9 +276,9 @@ def launcher(event_loop):
if num_shard is not None: if num_shard is not None:
args.extend(["--num-shard", str(num_shard)]) args.extend(["--num-shard", str(num_shard)])
if quantize: if quantize is not None:
args.append("--quantize") args.append("--quantize")
args.append("bitsandbytes") args.append(quantize)
if trust_remote_code: if trust_remote_code:
args.append("--trust-remote-code") args.append("--trust-remote-code")

View File

@ -0,0 +1,88 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.59375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6640625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 29918,
"logprob": -2.3867188,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -2.8183594,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -1.6367188,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.0527344,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.6542969,
"special": false,
"text": " request"
},
{
"id": 29918,
"logprob": -0.056121826,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -0.01600647,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -0.87939453,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.7529297,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.2980957,
"special": false,
"text": " request"
}
]
},
"generated_text": "_uri\nTest request_uri\nTest request"
}

View File

@ -0,0 +1,88 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.6015625,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6640625,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 29899,
"logprob": -1.1640625,
"special": false,
"text": "-"
},
{
"id": 1454,
"logprob": -0.07543945,
"special": false,
"text": "for"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 9342,
"logprob": 0.0,
"special": false,
"text": "comment"
},
{
"id": 29901,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 396,
"logprob": -0.2956543,
"special": false,
"text": " #"
},
{
"id": 29906,
"logprob": -0.52734375,
"special": false,
"text": "2"
},
{
"id": 29900,
"logprob": -0.6899414,
"special": false,
"text": "0"
},
{
"id": 29896,
"logprob": 0.0,
"special": false,
"text": "1"
},
{
"id": 29946,
"logprob": -1.5068359,
"special": false,
"text": "4"
}
]
},
"generated_text": "Test request-for-comment: #2014"
}

View File

@ -0,0 +1,354 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.6015625,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.671875,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 29918,
"logprob": -2.3828125,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -2.8105469,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -1.6396484,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.0546875,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.6513672,
"special": false,
"text": " request"
},
{
"id": 29918,
"logprob": -0.056365967,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -0.016082764,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -0.87841797,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.7548828,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.29711914,
"special": false,
"text": " request"
}
]
},
"generated_text": "_uri\nTest request_uri\nTest request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.6015625,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6640625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 29918,
"logprob": -2.3828125,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -2.828125,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -1.6386719,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.0527344,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.6542969,
"special": false,
"text": " request"
},
{
"id": 29918,
"logprob": -0.055877686,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -0.016021729,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -0.8769531,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.7583008,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.29833984,
"special": false,
"text": " request"
}
]
},
"generated_text": "_uri\nTest request_uri\nTest request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.6015625,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.671875,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 29918,
"logprob": -2.3847656,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -2.8144531,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -1.6396484,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.0527344,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.65478516,
"special": false,
"text": " request"
},
{
"id": 29918,
"logprob": -0.056243896,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -0.016143799,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -0.8808594,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.75341797,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.2956543,
"special": false,
"text": " request"
}
]
},
"generated_text": "_uri\nTest request_uri\nTest request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.6015625,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6640625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 29918,
"logprob": -2.3769531,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -2.8183594,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -1.6396484,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.0546875,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.65478516,
"special": false,
"text": " request"
},
{
"id": 29918,
"logprob": -0.05557251,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -0.01612854,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -0.8730469,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.7519531,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.29785156,
"special": false,
"text": " request"
}
]
},
"generated_text": "_uri\nTest request_uri\nTest request"
}
]

View File

@ -0,0 +1,193 @@
{
"generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L",
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"seed": null,
"prefill": [
{
"id": 589,
"text": "def",
"logprob": null
},
{
"id": 3226,
"text": " ge",
"logprob": -9.0234375
},
{
"id": 21017,
"text": "ometric",
"logprob": -9.0859375
},
{
"id": 81,
"text": "_",
"logprob": -0.25878906
},
{
"id": 6009,
"text": "mean",
"logprob": -2.2109375
},
{
"id": 26,
"text": "(",
"logprob": -0.30371094
},
{
"id": 62,
"text": "L",
"logprob": -5.6054688
},
{
"id": 44,
"text": ":",
"logprob": -3.0722656
},
{
"id": 1682,
"text": " List",
"logprob": -0.6879883
},
{
"id": 77,
"text": "[",
"logprob": -0.38500977
},
{
"id": 1808,
"text": "float",
"logprob": -0.984375
},
{
"id": 10794,
"text": "]):",
"logprob": -2.5351562
}
],
"tokens": [
{
"id": 284,
"text": "\n ",
"logprob": -1.1738281,
"special": false
},
{
"id": 442,
"text": " return",
"logprob": -0.95947266,
"special": false
},
{
"id": 3632,
"text": " sum",
"logprob": -1.4199219,
"special": false
},
{
"id": 26,
"text": "(",
"logprob": -0.085876465,
"special": false
},
{
"id": 62,
"text": "L",
"logprob": -0.09875488,
"special": false
},
{
"id": 27,
"text": ")",
"logprob": -0.30517578,
"special": false
},
{
"id": 517,
"text": " /",
"logprob": -0.42089844,
"special": false
},
{
"id": 2069,
"text": " len",
"logprob": -0.042053223,
"special": false
},
{
"id": 26,
"text": "(",
"logprob": -0.0011806488,
"special": false
},
{
"id": 62,
"text": "L",
"logprob": -0.0005259514,
"special": false
},
{
"id": 27,
"text": ")",
"logprob": -0.0017633438,
"special": false
},
{
"id": 478,
"text": "\n\n",
"logprob": -0.69189453,
"special": false
},
{
"id": 203,
"text": "\n",
"logprob": -0.041870117,
"special": false
},
{
"id": 589,
"text": "def",
"logprob": -0.27856445,
"special": false
},
{
"id": 3226,
"text": " ge",
"logprob": -1.7255859,
"special": false
},
{
"id": 21017,
"text": "ometric",
"logprob": -0.011291504,
"special": false
},
{
"id": 81,
"text": "_",
"logprob": -0.008430481,
"special": false
},
{
"id": 6009,
"text": "mean",
"logprob": -0.025787354,
"special": false
},
{
"id": 26,
"text": "(",
"logprob": -0.073913574,
"special": false
},
{
"id": 62,
"text": "L",
"logprob": -0.09967041,
"special": false
}
]
}
}

View File

@ -0,0 +1,193 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [
{
"id": 589,
"logprob": null,
"text": "def"
},
{
"id": 3226,
"logprob": -9.0234375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.09375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25976562,
"text": "_"
},
{
"id": 6009,
"logprob": -2.2148438,
"text": "mean"
},
{
"id": 26,
"logprob": -0.3010254,
"text": "("
},
{
"id": 62,
"logprob": -5.6757812,
"text": "L"
},
{
"id": 44,
"logprob": -3.0898438,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6791992,
"text": " List"
},
{
"id": 77,
"logprob": -0.38891602,
"text": "["
},
{
"id": 1808,
"logprob": -0.92041016,
"text": "float"
},
{
"id": 10794,
"logprob": -2.5390625,
"text": "]):"
}
],
"seed": 0,
"tokens": [
{
"id": 284,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": 0.0,
"special": false,
"text": " return"
},
{
"id": 11665,
"logprob": -1.6005859,
"special": false,
"text": " reduce"
},
{
"id": 26,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 5962,
"logprob": 0.0,
"special": false,
"text": "lambda"
},
{
"id": 816,
"logprob": 0.0,
"special": false,
"text": " x"
},
{
"id": 30,
"logprob": 0.0,
"special": false,
"text": ","
},
{
"id": 533,
"logprob": 0.0,
"special": false,
"text": " y"
},
{
"id": 44,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 816,
"logprob": 0.0,
"special": false,
"text": " x"
},
{
"id": 319,
"logprob": 0.0,
"special": false,
"text": " *"
},
{
"id": 533,
"logprob": 0.0,
"special": false,
"text": " y"
},
{
"id": 30,
"logprob": 0.0,
"special": false,
"text": ","
},
{
"id": 498,
"logprob": 0.0,
"special": false,
"text": " L"
},
{
"id": 27,
"logprob": 0.0,
"special": false,
"text": ")"
},
{
"id": 203,
"logprob": -0.11968994,
"special": false,
"text": "\n"
},
{
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 589,
"logprob": 0.0,
"special": false,
"text": "def"
},
{
"id": 3226,
"logprob": 0.0,
"special": false,
"text": " ge"
},
{
"id": 21017,
"logprob": 0.0,
"special": false,
"text": "ometric"
}
]
},
"generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric"
}

View File

@ -0,0 +1,534 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 589,
"logprob": null,
"text": "def"
},
{
"id": 3226,
"logprob": -9.0234375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25927734,
"text": "_"
},
{
"id": 6009,
"logprob": -2.25,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30126953,
"text": "("
},
{
"id": 62,
"logprob": -5.7539062,
"text": "L"
},
{
"id": 44,
"logprob": -3.0878906,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6845703,
"text": " List"
},
{
"id": 77,
"logprob": -0.3918457,
"text": "["
},
{
"id": 1808,
"logprob": -0.8798828,
"text": "float"
},
{
"id": 10794,
"logprob": -2.4980469,
"text": "]):"
}
],
"seed": null,
"tokens": [
{
"id": 284,
"logprob": -1.1533203,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.91796875,
"special": false,
"text": " return"
},
{
"id": 3632,
"logprob": -1.3291016,
"special": false,
"text": " sum"
},
{
"id": 26,
"logprob": -0.08062744,
"special": false,
"text": "("
},
{
"id": 62,
"logprob": -0.097717285,
"special": false,
"text": "L"
},
{
"id": 27,
"logprob": -0.29003906,
"special": false,
"text": ")"
},
{
"id": 517,
"logprob": -0.34958984,
"special": false,
"text": " /"
},
{
"id": 2069,
"logprob": -0.03829956,
"special": false,
"text": " len"
},
{
"id": 26,
"logprob": -0.0011987686,
"special": false,
"text": "("
},
{
"id": 62,
"logprob": -0.00050878525,
"special": false,
"text": "L"
}
]
},
"generated_text": "\n return sum(L) / len(L"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 589,
"logprob": null,
"text": "def"
},
{
"id": 3226,
"logprob": -9.0234375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25878906,
"text": "_"
},
{
"id": 6009,
"logprob": -2.2109375,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30371094,
"text": "("
},
{
"id": 62,
"logprob": -5.6054688,
"text": "L"
},
{
"id": 44,
"logprob": -3.0722656,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6879883,
"text": " List"
},
{
"id": 77,
"logprob": -0.38500977,
"text": "["
},
{
"id": 1808,
"logprob": -0.984375,
"text": "float"
},
{
"id": 10794,
"logprob": -2.5351562,
"text": "]):"
}
],
"seed": null,
"tokens": [
{
"id": 284,
"logprob": -1.1738281,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.9584961,
"special": false,
"text": " return"
},
{
"id": 3632,
"logprob": -1.4169922,
"special": false,
"text": " sum"
},
{
"id": 26,
"logprob": -0.085876465,
"special": false,
"text": "("
},
{
"id": 62,
"logprob": -0.0982666,
"special": false,
"text": "L"
},
{
"id": 27,
"logprob": -0.3022461,
"special": false,
"text": ")"
},
{
"id": 517,
"logprob": -0.40504883,
"special": false,
"text": " /"
},
{
"id": 2069,
"logprob": -0.041656494,
"special": false,
"text": " len"
},
{
"id": 26,
"logprob": -0.0011844635,
"special": false,
"text": "("
},
{
"id": 62,
"logprob": -0.0005264282,
"special": false,
"text": "L"
}
]
},
"generated_text": "\n return sum(L) / len(L"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 589,
"logprob": null,
"text": "def"
},
{
"id": 3226,
"logprob": -9.0234375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25927734,
"text": "_"
},
{
"id": 6009,
"logprob": -2.25,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30126953,
"text": "("
},
{
"id": 62,
"logprob": -5.7539062,
"text": "L"
},
{
"id": 44,
"logprob": -3.0878906,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6845703,
"text": " List"
},
{
"id": 77,
"logprob": -0.3918457,
"text": "["
},
{
"id": 1808,
"logprob": -0.8798828,
"text": "float"
},
{
"id": 10794,
"logprob": -2.4980469,
"text": "]):"
}
],
"seed": null,
"tokens": [
{
"id": 284,
"logprob": -1.1533203,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.9165039,
"special": false,
"text": " return"
},
{
"id": 3632,
"logprob": -1.328125,
"special": false,
"text": " sum"
},
{
"id": 26,
"logprob": -0.07946777,
"special": false,
"text": "("
},
{
"id": 62,
"logprob": -0.09820557,
"special": false,
"text": "L"
},
{
"id": 27,
"logprob": -0.28930664,
"special": false,
"text": ")"
},
{
"id": 517,
"logprob": -0.34592773,
"special": false,
"text": " /"
},
{
"id": 2069,
"logprob": -0.038330078,
"special": false,
"text": " len"
},
{
"id": 26,
"logprob": -0.0011940002,
"special": false,
"text": "("
},
{
"id": 62,
"logprob": -0.00050878525,
"special": false,
"text": "L"
}
]
},
"generated_text": "\n return sum(L) / len(L"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 589,
"logprob": null,
"text": "def"
},
{
"id": 3226,
"logprob": -9.0234375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25927734,
"text": "_"
},
{
"id": 6009,
"logprob": -2.25,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30126953,
"text": "("
},
{
"id": 62,
"logprob": -5.7539062,
"text": "L"
},
{
"id": 44,
"logprob": -3.0878906,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6845703,
"text": " List"
},
{
"id": 77,
"logprob": -0.3918457,
"text": "["
},
{
"id": 1808,
"logprob": -0.8798828,
"text": "float"
},
{
"id": 10794,
"logprob": -2.4980469,
"text": "]):"
}
],
"seed": null,
"tokens": [
{
"id": 284,
"logprob": -1.1533203,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.91259766,
"special": false,
"text": " return"
},
{
"id": 3632,
"logprob": -1.3251953,
"special": false,
"text": " sum"
},
{
"id": 26,
"logprob": -0.08062744,
"special": false,
"text": "("
},
{
"id": 62,
"logprob": -0.09906006,
"special": false,
"text": "L"
},
{
"id": 27,
"logprob": -0.28979492,
"special": false,
"text": ")"
},
{
"id": 517,
"logprob": -0.35958984,
"special": false,
"text": " /"
},
{
"id": 2069,
"logprob": -0.038604736,
"special": false,
"text": " len"
},
{
"id": 26,
"logprob": -0.0011901855,
"special": false,
"text": "("
},
{
"id": 62,
"logprob": -0.0005078316,
"special": false,
"text": "L"
}
]
},
"generated_text": "\n return sum(L) / len(L"
}
]

View File

@ -0,0 +1,57 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_gptq_handle(launcher):
with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_gptq(flash_llama_gptq_handle):
await flash_llama_gptq_handle.health(300)
return flash_llama_gptq_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
response = await flash_llama_gptq.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
response = await flash_llama_gptq.generate(
"Test request",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_load(flash_llama_gptq, generate_load, response_snapshot):
responses = await generate_load(flash_llama_gptq, "Test request", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -0,0 +1,49 @@
import pytest
@pytest.fixture(scope="module")
def flash_starcoder_gptq_handle(launcher):
with launcher("Narsil/starcoder-gptq", num_shard=2, quantize="gptq") as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
await flash_starcoder_gptq_handle.health(300)
return flash_starcoder_gptq_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq(flash_starcoder_gptq, response_snapshot):
response = await flash_starcoder_gptq.generate(
"def geometric_mean(L: List[float]):", max_new_tokens=20, decoder_input_details=True,
)
assert response.details.generated_tokens == 20
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq_default_params(flash_starcoder_gptq, response_snapshot):
response = await flash_starcoder_gptq.generate(
"def geometric_mean(L: List[float]):",
max_new_tokens=20,
temperature=0.2,
top_p=0.95,
decoder_input_details=True,
seed=0,
)
assert response.details.generated_tokens == 20
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq_load(flash_starcoder_gptq, generate_load, response_snapshot):
responses = await generate_load(flash_starcoder_gptq, "def geometric_mean(L: List[float]):", max_new_tokens=10, n=4)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -0,0 +1,71 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#define _cuda_buffers_cu
#include "cuda_buffers.cuh"
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;
CudaBuffers::CudaBuffers
(
int _device,
half* _temp_state,
half* _temp_dq
) :
device(_device),
temp_state(_temp_state),
temp_dq(_temp_dq)
{
cudaSetDevice(_device);
cudaStreamCreate(&alt_stream_1);
cudaStreamCreate(&alt_stream_2);
cudaStreamCreate(&alt_stream_3);
cudaEventCreate(&alt_stream_1_done);
cudaEventCreate(&alt_stream_2_done);
cudaEventCreate(&alt_stream_3_done);
}
CudaBuffers::~CudaBuffers()
{
cudaStreamDestroy(alt_stream_1);
cudaStreamDestroy(alt_stream_2);
cudaStreamDestroy(alt_stream_3);
cudaEventDestroy(alt_stream_1_done);
cudaEventDestroy(alt_stream_2_done);
cudaEventDestroy(alt_stream_3_done);
}
CudaBuffers* get_buffers(const int device_index)
{
return g_buffers[device_index];
}
void prepare_buffers_cuda
(
int _device,
half* _temp_state,
half* _temp_dq
)
{
CudaBuffers* buffers = new CudaBuffers
(
_device,
_temp_state,
_temp_dq
);
g_buffers[_device] = buffers;
}
void cleanup_buffers_cuda()
{
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
{
if (!g_buffers[i]) continue;
delete g_buffers[i];
g_buffers[i] = NULL;
}
}

View File

@ -0,0 +1,52 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
const int CUDA_MAX_DEVICES = 16;
// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif
class CudaBuffers
{
public:
int device;
half* temp_state; // [max_hidden_rows * intermediate_size]
half* temp_dq; // size of largest quant tensor * 8
cudaStream_t alt_stream_1;
cudaStream_t alt_stream_2;
cudaStream_t alt_stream_3;
cudaEvent_t alt_stream_1_done;
cudaEvent_t alt_stream_2_done;
cudaEvent_t alt_stream_3_done;
CudaBuffers
(
int _device,
half* _temp_state,
half* _temp_dq
);
~CudaBuffers();
};
CudaBuffers* get_buffers(const int device_index);
void prepare_buffers_cuda
(
int _device,
half* _temp_state,
half* _temp_dq
);
void cleanup_buffers_cuda();
#endif

View File

@ -0,0 +1,58 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ < 700
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
#endif

View File

@ -0,0 +1,61 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "column_remap.cuh"
#include "../util.cuh"
const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;
__global__ void column_remap_kernel
(
const half* __restrict__ x,
half* __restrict__ x_new,
const int x_width,
const int x_height,
const uint32_t* x_map
)
{
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
int x_stride = x_width;
int x_idx = x_row * x_stride + x_column;
int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
int x_idx_end = x_row_end * x_stride + x_column;
int s_column = x_map[x_column];
int s_idx = x_row * x_stride + s_column;
while (x_idx < x_idx_end)
{
x_new[x_idx] = x[s_idx];
x_idx += x_stride;
s_idx += x_stride;
}
}
// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
)
{
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks
(
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
1
);
column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}

View File

@ -0,0 +1,19 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _column_remap_cuh
#define _column_remap_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
);
#endif

View File

@ -0,0 +1,252 @@
#include "q4_matmul.cuh"
#include "column_remap.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#include "../cuda_compat.cuh"
#include "../cuda_buffers.cuh"
const int THREADS_X = 32; // Block size and thread count along columns in w and out
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
typedef void (*fp_q4_matmul_kernel)
(
const half*,
const uint32_t*,
half*,
const half*,
const uint32_t*,
const int,
const int,
const int,
const int,
const int,
const uint32_t*,
bool
);
template<bool use_half2, bool use_groupsize, bool use_x_map>
__global__ void q4_matmul_kernel
(
const half* __restrict__ x,
const uint32_t* __restrict__ w,
half* __restrict__ out,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int dim,
const int width,
const int groupsize,
const int block_size_z,
const uint32_t* __restrict__ x_map,
bool no_zero
)
{
// Start of block
int x_column = block_size_z * blockIdx.z;
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
int iterations = (x_column_end - x_column) / 8;
// Views
MatrixView_half x_(x, height, dim);
MatrixView_half w_scales_(w_scales, dim / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
MatrixView_q4_column w_(w, dim, width);
MatrixView_half_rw out_(out, height, width);
// Zero output
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
{
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
__syncthreads();
}
// Loop over part of x row (and w column)
half2 acc = {};
half acc_h = {};
if constexpr (use_groupsize)
{
// For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
// could be slightly faster
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
{
if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
}
}
}
else
{
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
for (int k = x_column; k < x_column + iterations * 8; k += 8)
{
if constexpr (use_half2)
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
else
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
}
}
}
// Add to block result
if constexpr (use_half2)
{
half result = __hadd(acc.x, acc.y);
atomicAdd(out_.item_ptr(x_row, w_column), result);
}
else
{
atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
}
}
fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)
{
// <bool use_half2, bool use_groupsize, bool use_x_map>
if (tuningParams->matmul_no_half2) {
if (block_size_z % groupsize == 0) {
if (x_map) return q4_matmul_kernel<false, true, true >;
else return q4_matmul_kernel<false, true, false>;
} else {
if (x_map) return q4_matmul_kernel<false, false, true >;
else return q4_matmul_kernel<false, false, false>;
}
} else {
if (block_size_z % groupsize == 0)
{
if (x_map) return q4_matmul_kernel<true, true, true >;
else return q4_matmul_kernel<true, true, false>;
} else {
if (x_map) return q4_matmul_kernel<true, false, true >;
else return q4_matmul_kernel<true, false, false>;
}
}
};
// Compute y = x @ w
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero,
cudaStream_t alt_stream
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
uint32_t* x_map = w->cuda_x_map;
const half* x_mapped = x;
if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)
{
CudaBuffers* buffers = get_buffers(w->device);
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
x_map = NULL;
}
int block_size_z;
if (w->width == 4096) block_size_z = 384; // 7B
else if (w->width == 11008) block_size_z = 256;
else if (w->width == 5120) block_size_z = 384; // 13B
else if (w->width == 13824) block_size_z = 256;
else if (w->width == 6656) block_size_z = 256; // 33B
else if (w->width == 17920) block_size_z = 128;
else block_size_z = 256;
//if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));
dim3 threads(THREADS_X, THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height + threads.y - 1) / threads.y,
(dim + block_size_z - 1) / block_size_z
);
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero
)
{
int height = x_height;
int dim = w->height;
int width = w->width;
cudaSetDevice(w->device);
CudaBuffers* buffers = get_buffers(w->device);
const half* x_mapped = x;
if (w->cuda_x_map)
{
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
x_mapped = buffers->temp_state;
}
w->reconstruct(buffers->temp_dq);
const half alpha = __float2half(1.0f);
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
// const float alpha = 1.0f;
// const float beta = no_zero ? 1.0f : 0.0f;
// cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
// x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
}

View File

@ -0,0 +1,37 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matmul_cuh
#define _q4_matmul_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>
#include "q4_matrix.cuh"
#include "../tuning.h"
void q4_matmul_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
const Q4Matrix* w,
half* out,
bool no_zero = false,
cudaStream_t alt_stream = NULL
);
void q4_matmul_recons_cuda
(
ExLlamaTuning* tuningParams,
const half* x,
const int x_height,
Q4Matrix* w,
half* out,
const cublasHandle_t handle,
bool no_zero = false
);
#endif

View File

@ -0,0 +1,217 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include "q4_matrix.cuh"
#include <vector>
#include "../util.cuh"
#include "../matrix.cuh"
using namespace std;
const int UNSHUF_BLOCKSIZE_X = 64;
const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column
const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows
vector<Q4Matrix*> g_q4_matrices;
void g_q4_keep_matrix(Q4Matrix* m)
{
g_q4_matrices.push_back(m);
}
void g_q4_free_matrices()
{
for (const auto& m : g_q4_matrices) delete m;
g_q4_matrices.clear();
}
Q4Matrix::Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
) :
height(_height),
width(_width),
groups(_groups),
device(_device)
{
cudaSetDevice(device);
cuda_qweight = _qweight;
cuda_qzeros = _qzeros;
cuda_scales = _scales;
groupsize = height / groups;
if (_g_idx) make_sequential(_g_idx);
}
Q4Matrix::~Q4Matrix()
{
}
// Make sequential
__global__ void make_sequential_kernel
(
const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const uint32_t* __restrict__ x_map,
const int w_height,
const int w_width
)
{
const uint64_t* w2 = (uint64_t*) w;
uint64_t* w_new2 = (uint64_t*) w_new;
int w2_stride = w_width >> 1;
int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
int w_new2_row = blockIdx.y;
int x_map_idx = w_new2_row << 3;
uint64_t dst = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
int source_row = x_map[x_map_idx++];
int w2_row = source_row >> 3;
int w2_subrow = source_row & 0x07;
int w2_row_shift = w2_subrow << 2;
int wnew2_row_shift = i << 2;
uint64_t src = w2[w2_row * w2_stride + w2_column];
src >>= w2_row_shift;
src &= 0x0000000f0000000f;
src <<= wnew2_row_shift;
dst |= src;
}
w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
{
uint32_t* cuda_new_qweight = NULL;
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
// Group histogram
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
// Group map
for (int i = 0, acc = 0; i < groups; i++)
{
short tmp = cpu_g_idx_map[i];
cpu_g_idx_map[i] = acc;
acc += tmp;
}
// X map (inverse)
for (int row = 0; row < height; row++)
{
uint32_t target_group = cpu_g_idx[row];
uint32_t target_row = cpu_g_idx_map[target_group];
cpu_g_idx_map[target_group]++;
cpu_x_map_inv[row] = target_row;
}
// X map
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
// Move to CUDA
cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);
// Rearrange rows in w
dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);
make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);
// Replace qweights
cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
// Cleanup
cudaDeviceSynchronize();
cudaFree(cuda_new_qweight);
free(cpu_g_idx_map);
free(cpu_x_map);
free(cpu_x_map_inv);
}
__global__ void reconstruct_kernel
(
const uint32_t* __restrict__ w,
half* __restrict__ out, // (y)
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int height,
const int width,
const int groupsize
)
{
// Start of block
int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
// Views
MatrixView_q4_column w_(w, height, width);
MatrixView_half_rw out_(out, height, width);
MatrixView_half w_scales_(w_scales, height / groupsize, width);
MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);
// Groupsize version
int group = row / groupsize;
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_read = w_.item_uint32_t(row, column);
half* out_ptr = out_.item_ptr(row, column);
#pragma unroll
for (int s = 0; s < 32; s += 4)
{
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
*out_ptr = w_item; out_ptr += out_.width;
}
}
void Q4Matrix::reconstruct(half* out)
{
dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);
dim3 blocks
(
(width + threads.x - 1) / threads.x,
(height / 8 + threads.y - 1) / threads.y,
1
);
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
}

View File

@ -0,0 +1,53 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _q4_matrix_cuh
#define _q4_matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
class Q4Matrix
{
public:
int device;
int height;
int width;
int groups;
int groupsize;
uint32_t* cuda_qweight = NULL;
uint32_t* cuda_qzeros = NULL;
half* cuda_scales = NULL;
uint32_t* cuda_x_map = NULL;
Q4Matrix
(
const int _height,
const int _width,
const int _groups,
uint32_t* _qweight,
uint32_t* _qzeros,
half* _scales,
uint32_t* _g_idx,
const int _device
);
~Q4Matrix();
void reconstruct(half* out);
private:
void make_sequential(const uint32_t* cpu_g_idx);
};
void g_q4_keep_matrix(Q4Matrix* m);
void g_q4_free_matrices();
#endif

View File

@ -0,0 +1,249 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "cuda_func/q4_matrix.cuh"
#include "cuda_func/q4_matmul.cuh"
#include "cuda_func/column_remap.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
void check_cuda(cudaError_t ret)
{
switch (ret)
{
case cudaSuccess:
break;
case cudaUnspecified:
printf(" **** Unspecified error\n");
TORCH_CHECK(false, "CUDA error");
break;
default:
printf(" **** CUDA error\n"); \
printf(" **** %s\n", cudaGetErrorString(ret)); \
TORCH_CHECK(false, "CUDA error"); \
break;
}
}
// Some decluttering macros
#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
TORCH_CHECK(__index >= 0, "no device index"); \
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
TORCH_CHECK_DTYPE(__w, kInt); \
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize;
}
// Tuning parameters
ExLlamaTuning tuningParams;
void set_tuning_params
(
int matmul_recons_thd,
bool matmul_fused_remap,
bool matmul_no_half2
)
{
tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.matmul_no_half2 = matmul_no_half2;
}
// Release all unmanaged objects allocated by the extension
void cleanup()
{
cleanup_buffers_cuda();
g_q4_free_matrices();
}
// Prepare buffers for forward pass
void prepare_buffers
(
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_dq
)
{
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
prepare_buffers_cuda
(
device_index,
(half*) temp_state.data_ptr(),
(half*) temp_dq.data_ptr()
);
}
// Create Q4Matrix, return handle
uintptr_t make_q4
(
torch::Tensor qweight,
torch::Tensor qzeros,
torch::Tensor scales,
torch::Tensor g_idx,
int device
)
{
TORCH_CHECK_DTYPE(qweight, kInt);
TORCH_CHECK_DTYPE(qzeros, kInt);
TORCH_CHECK_DTYPE(scales, kHalf);
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
int width = qweight.size(1);
int height = qweight.size(0) * 8;
int groups = qzeros.size(0);
Q4Matrix* m = new Q4Matrix
(
height,
width,
groups,
(uint32_t*) qweight.data_ptr(),
(uint32_t*) qzeros.data_ptr(),
(half*) scales.data_ptr(),
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
device
);
g_q4_keep_matrix(m);
return reinterpret_cast<uintptr_t> (m);
}
// Matmul half @ quant -> half
void q4_matmul
(
torch::Tensor x,
uintptr_t w,
torch::Tensor out
)
{
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int x_height = x.size(0);
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr()
);
}
else
{
q4_matmul_recons_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
at::cuda::getCurrentCUDABlasHandle()
);
}
}
// Remap columns in half tensor
void column_remap
(
torch::Tensor x,
torch::Tensor x_new,
torch::Tensor x_map
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(x_new, kHalf);
TORCH_CHECK_DTYPE(x_map, kInt);
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
int height = x.size(0);
int width = x.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
column_remap_cuda
(
(half*) x.data_ptr(),
(half*) x_new.data_ptr(),
height,
width,
(uint32_t*) x_map.data_ptr()
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
}

View File

@ -0,0 +1,294 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _matrix_cuh
#define _matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
};
class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
__device__ __forceinline__ half2 dot_product_8
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
half2 tmp = __hmul2(*h_ptr++, v_01);
tmp = __hfma2(*h_ptr++, v_23, tmp);
tmp = __hfma2(*h_ptr++, v_45, tmp);
tmp = __hfma2(*h_ptr++, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half* h_ptr = h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(*h_ptr++, v_0);
tmp = __hfma(*h_ptr++, v_1, tmp);
tmp = __hfma(*h_ptr++, v_2, tmp);
tmp = __hfma(*h_ptr++, v_3, tmp);
tmp = __hfma(*h_ptr++, v_4, tmp);
tmp = __hfma(*h_ptr++, v_5, tmp);
tmp = __hfma(*h_ptr++, v_6, tmp);
tmp = __hfma(*h_ptr++, v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
__device__ __forceinline__ half2 dot_product_8_x_map
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
half h_0 = h_ptr[*x_map_ptr++];
half h_1 = h_ptr[*x_map_ptr++];
half h_2 = h_ptr[*x_map_ptr++];
half h_3 = h_ptr[*x_map_ptr++];
half h_4 = h_ptr[*x_map_ptr++];
half h_5 = h_ptr[*x_map_ptr++];
half h_6 = h_ptr[*x_map_ptr++];
half h_7 = h_ptr[*x_map_ptr++];
half2 h_01 = __halves2half2(h_0, h_1);
half2 h_23 = __halves2half2(h_2, h_3);
half2 h_45 = __halves2half2(h_4, h_5);
half2 h_67 = __halves2half2(h_6, h_7);
half2 tmp = __hmul2(h_01, v_01);
tmp = __hfma2(h_23, v_23, tmp);
tmp = __hfma2(h_45, v_45, tmp);
tmp = __hfma2(h_67, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_x_map_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
#endif

View File

@ -0,0 +1,13 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _tuning_h
#define _tuning_h
struct ExLlamaTuning
{
int matmul_recons_thd;
bool matmul_fused_remap;
bool matmul_no_half2;
};
#endif

View File

@ -0,0 +1,29 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#define cudaUnspecified cudaErrorApiFailureBase
// React to failure on return code != cudaSuccess
#define _cuda_check(fn) \
do { \
{_cuda_err = fn;} \
if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)
// React to failure on return code == 0
#define _alloc_check(fn) \
do { \
if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
else _cuda_err = cudaSuccess; \
} while(false)
#endif

View File

@ -0,0 +1,19 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name="exllama_kernels",
ext_modules=[
CUDAExtension(
name="exllama_kernels",
sources=[
"exllama_kernels/exllama_ext.cpp",
"exllama_kernels/cuda_buffers.cu",
"exllama_kernels/cuda_func/column_remap.cu",
"exllama_kernels/cuda_func/q4_matmul.cu",
"exllama_kernels/cuda_func/q4_matrix.cu"
],
)
],
cmdclass={"build_ext": BuildExtension},
)

View File

@ -383,7 +383,6 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module): class FlashLlamaModel(torch.nn.Module):
def __init__(self, config, weights): def __init__(self, config, weights):
super().__init__() super().__init__()
self.config = config
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()

View File

@ -20,7 +20,6 @@ from text_generation_server.utils.layers import (
) )
from safetensors import SafetensorError from safetensors import SafetensorError
def load_multi_mqa( def load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
): ):
@ -50,6 +49,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop] q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :] kv_tensor = slice_[:, -2 * head_size :]
qweight = torch.cat([q_tensor, kv_tensor], dim=1) qweight = torch.cat([q_tensor, kv_tensor], dim=1)
qweight = qweight.to(device=weights.device)
slice_ = weights._get_slice(f"{prefix}.c_attn.scales") slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
shape = slice_.get_shape() shape = slice_.get_shape()
@ -60,6 +60,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop] q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :] kv_tensor = slice_[:, -2 * head_size :]
scales = torch.cat([q_tensor, kv_tensor], dim=1) scales = torch.cat([q_tensor, kv_tensor], dim=1)
scales = scales.to(device=weights.device)
slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros") slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
shape = slice_.get_shape() shape = slice_.get_shape()
@ -70,21 +71,15 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop] q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size * 4 // 32 :] kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
qzeros = torch.cat([q_tensor, kv_tensor], dim=1) qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device)
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx") g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
try: g_idx = g_idx.to(device=weights.device)
bits = weights.get_tensor("gptq_bits").item() bits, groupsize = weights._get_gptq_qparams()
groupsize = weights.get_tensor("gptq_groupsize").item()
except SafetensorError as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS")) from text_generation_server.utils.layers import HAS_EXLLAMA
groupsize = int(os.getenv("GPTQ_GROUPSIZE")) use_exllama = HAS_EXLLAMA
except Exception: weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
raise e
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
if bias: if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias") slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
@ -97,6 +92,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[start:stop] q_tensor = slice_[start:stop]
kv_tensor = slice_[-2 * head_size :] kv_tensor = slice_[-2 * head_size :]
bias = torch.cat([q_tensor, kv_tensor], dim=0) bias = torch.cat([q_tensor, kv_tensor], dim=0)
bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
else: else:
@ -361,7 +357,6 @@ class Block(nn.Module):
max_s, max_s,
): ):
hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn( hidden_states = self.attn(
hidden_states, hidden_states,
cu_seqlen_prefill, cu_seqlen_prefill,

View File

@ -6,7 +6,6 @@ import torch.distributed
import numpy as np import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace from opentelemetry import trace
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict from typing import Optional, Tuple, List, Type, Union, Dict

View File

@ -3,14 +3,13 @@ import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, GeneratedText from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch) B = TypeVar("B", bound=Batch)
class Model(ABC): class Model(ABC):
def __init__( def __init__(
self, self,

View File

@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: Model, cache: Cache, server_urls: List[str]): def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
self.cache = cache self.cache = cache
@ -140,6 +141,16 @@ def serve(
logger.exception("Error when initializing model") logger.exception("Error when initializing model")
raise raise
if quantize == "gptq":
try:
# When using GPTQ, Exllama kernels need some global kernels
# For which we have the finale shapes only after the model has loaded
# This will allocate those buffers.
from text_generation_server.utils.gptq.exllama import create_exllama_buffers
create_exllama_buffers()
except ImportError:
pass
server = aio.server( server = aio.server(
interceptors=[ interceptors=[
ExceptionInterceptor(), ExceptionInterceptor(),

View File

@ -0,0 +1,121 @@
import torch
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device = "meta")
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(qweight,
qzeros,
scales,
g_idx if g_idx is not None else none_tensor,
device)
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
q4_matmul(x, q4, output)
return output.view(outshape)
MAX_DQ = 1
MAX_INNER = 1
ACT_ORDER = False
DEVICE = None
TEMP_STATE = None
TEMP_DQ = None
def create_exllama_buffers():
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
if ACT_ORDER:
# TODO: this should be set to rust side `max_total_tokens`, but TGI
# does not offer an API to expose this variable to python, as this variable
# is handled by the client but it appears the model is initialized by the server.
# An alternative could be to initialize the buffers during warmup.
# Dummy
max_total_tokens = 2048
else:
max_total_tokens = 1
# This temp_state buffer is required to reorder X in the act-order case.
temp_state = torch.zeros((max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE)
temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
prepare_buffers(DEVICE, temp_state, temp_dq)
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
TEMP_STATE, TEMP_DQ = temp_state, temp_dq
class Ex4bitLinear:
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
assert bits == 4
self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx.cpu() if g_idx is not None else None
self.bias = bias if bias is not None else None
if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
self.empty_g_idx = True
self.g_idx = None
assert self.device.type == "cuda"
assert self.device.index is not None
self.q4 = ext_make_q4(
self.qweight,
self.qzeros,
self.scales,
self.g_idx,
self.device.index
)
self.height = qweight.shape[0] * 8
self.width = qweight.shape[1]
# Infer groupsize from height of qzeros
self.groupsize = None
if self.qzeros.shape[0] > 1:
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
if self.groupsize is not None:
assert groupsize == self.groupsize
# Handle act-order matrix
if self.g_idx is not None:
if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
self.act_order = True
else:
self.act_order = False
DEVICE = self.qweight.device
MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)
if self.act_order:
MAX_INNER = max(MAX_INNER, self.height, self.width)
ACT_ORDER = True
def forward(self, x):
out = ext_q4_matmul(x, self.q4, self.width)
if self.bias is not None:
out.add_(self.bias)
return out

View File

@ -1,3 +1,4 @@
import os
import torch import torch
import torch.distributed import torch.distributed
@ -16,7 +17,15 @@ except ImportError:
from accelerate import init_empty_weights from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear from text_generation_server.utils.gptq.quant_linear import QuantLinear
HAS_EXLLAMA = True
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA=False
try:
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
except ImportError:
HAS_EXLLAMA = False
from typing import Optional
# Monkey patching # Monkey patching
@classmethod @classmethod
@ -144,21 +153,24 @@ def get_linear(weight, bias, quantize):
linear.bias = nn.Parameter(bias) linear.bias = nn.Parameter(bias)
elif quantize == "gptq": elif quantize == "gptq":
try: try:
qweight, qzeros, scales, g_idx, bits, groupsize = weight qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
except Exception: except Exception:
raise NotImplementedError( raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated." f"The passed weight is not `gptq` compatible, loader needs to be updated."
) )
linear = QuantLinear( if use_exllama:
qweight, linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
qzeros, else:
scales, linear = QuantLinear(
g_idx, qweight,
bias, qzeros,
bits, scales,
groupsize, g_idx,
) bias,
bits,
groupsize,
)
else: else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear return linear

View File

@ -1,7 +1,8 @@
from pathlib import Path from pathlib import Path
from typing import List, Dict, Optional from typing import List, Dict, Optional, Tuple
from safetensors import safe_open, SafetensorError from safetensors import safe_open, SafetensorError
import torch import torch
from loguru import logger
class Weights: class Weights:
@ -127,18 +128,8 @@ class Weights:
torch.testing.assert_close(w2, w[0]) torch.testing.assert_close(w2, w[0])
g_idx = w[0] g_idx = w[0]
try: bits, groupsize = self._get_gptq_qparams()
bits = self.get_tensor("gptq_bits").item() weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
groupsize = self.get_tensor("gptq_groupsize").item()
except (SafetensorError, RuntimeError) as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
else: else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim) weight = torch.cat(w, dim=dim)
@ -146,29 +137,74 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str): def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq": if quantize == "gptq":
use_exllama = True
bits, groupsize = self._get_gptq_qparams()
if bits != 4:
use_exllama = False
if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
if g_idx is not None:
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
try: try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0) qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError: except RuntimeError:
raise RuntimeError( raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except (SafetensorError, RuntimeError) as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS")) from text_generation_server.utils.layers import HAS_EXLLAMA
groupsize = int(os.getenv("GPTQ_GROUPSIZE")) if use_exllama:
except Exception: if not HAS_EXLLAMA:
raise e logger.warning("Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True")
use_exllama = False
else:
logger.info("Using exllama kernels")
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
if use_exllama:
if groupsize >= 0:
# Exllama reorders the weights in advance and the activations on the fly, thus
# the scales and zero-points do not need to be reordered.
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
else:
raise RuntimeError("Using exllama GPTQ kernel with groupsize<1 is not supported")
# qzeros = self.get_tensor(f"{prefix}.qzeros")
# scales = self.get_tensor(f"{prefix}.scales")
# For tp > 1, at this point we know we do not use act-order
if self.process_group.size() == 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
else:
g_idx = None
else:
# The triton kernel reorders the scales/zero points instead of the weight/activation.
# Thus, each rank needs the full qzeros/scales.
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else: else:
weight = self.get_sharded(f"{prefix}.weight", dim=1) weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight return weight
def _get_gptq_qparams(self) -> Tuple[int, int]:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except (SafetensorError, RuntimeError) as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
return bits, groupsize