ROCm AWQ support (#1514)
# What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> This PR adds the possibility to run AWQ models with Exllama/GPTQ kernels, specifically for ROCm devices that support Exllama kernels but not AWQ's GEMM. This is done by : - un-packing, reordering and re-packing AWQ weights when `--quantize gptq` but the model's `quant_method=awq`. - avoiding overflows when adding 1 to zeros in exllama and triton. Ref: https://github.com/casper-hansen/AutoAWQ/pull/313 ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil --> --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
c5ef81bed5
commit
a4e5801684
|
@ -11,78 +11,79 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4321,
|
"id": 4321,
|
||||||
"logprob": -9.59375,
|
"logprob": -9.7890625,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -9.6640625,
|
"logprob": -9.625,
|
||||||
"text": "request"
|
"text": "request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
|
||||||
"id": 29918,
|
|
||||||
"logprob": -2.3867188,
|
|
||||||
"special": false,
|
|
||||||
"text": "_"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5338,
|
|
||||||
"logprob": -2.8183594,
|
|
||||||
"special": false,
|
|
||||||
"text": "uri"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.6367188,
|
"logprob": -2.3359375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3057,
|
"id": 3057,
|
||||||
"logprob": -1.0527344,
|
"logprob": -1.8779297,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -0.6542969,
|
"logprob": -1.2744141,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": 29918,
|
|
||||||
"logprob": -0.056121826,
|
|
||||||
"special": false,
|
|
||||||
"text": "_"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5338,
|
|
||||||
"logprob": -0.01600647,
|
|
||||||
"special": false,
|
|
||||||
"text": "uri"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.87939453,
|
"logprob": -1.6933594,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3057,
|
"id": 3057,
|
||||||
"logprob": -0.7529297,
|
"logprob": -1.4648438,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -0.2980957,
|
"logprob": -0.15600586,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.8027344,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3057,
|
||||||
|
"logprob": -0.23022461,
|
||||||
|
"special": false,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2009,
|
||||||
|
"logprob": -0.0069885254,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.02218628,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "_uri\nTest request_uri\nTest request"
|
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,12 +11,12 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4321,
|
"id": 4321,
|
||||||
"logprob": -9.6015625,
|
"logprob": -9.84375,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -9.6640625,
|
"logprob": -9.6015625,
|
||||||
"text": "request"
|
"text": "request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -24,13 +24,13 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 29899,
|
"id": 29899,
|
||||||
"logprob": -1.1640625,
|
"logprob": -1.5625,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "-"
|
"text": "-"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1454,
|
"id": 1454,
|
||||||
"logprob": -0.07543945,
|
"logprob": -0.20410156,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "for"
|
"text": "for"
|
||||||
},
|
},
|
||||||
|
@ -54,19 +54,19 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 396,
|
"id": 396,
|
||||||
"logprob": -0.2956543,
|
"logprob": -0.27685547,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " #"
|
"text": " #"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29906,
|
"id": 29906,
|
||||||
"logprob": -0.52734375,
|
"logprob": -0.4970703,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "2"
|
"text": "2"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29900,
|
"id": 29900,
|
||||||
"logprob": -0.6899414,
|
"logprob": -0.80615234,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "0"
|
"text": "0"
|
||||||
},
|
},
|
||||||
|
@ -77,12 +77,13 @@
|
||||||
"text": "1"
|
"text": "1"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29946,
|
"id": 29955,
|
||||||
"logprob": -1.5068359,
|
"logprob": -1.0751953,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "4"
|
"text": "7"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Test request-for-comment: #2014"
|
"generated_text": "Test request-for-comment: #2017"
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,80 +12,81 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4321,
|
"id": 4321,
|
||||||
"logprob": -9.6015625,
|
"logprob": -9.828125,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -9.671875,
|
"logprob": -9.609375,
|
||||||
"text": "request"
|
"text": "request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
|
||||||
"id": 29918,
|
|
||||||
"logprob": -2.3828125,
|
|
||||||
"special": false,
|
|
||||||
"text": "_"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5338,
|
|
||||||
"logprob": -2.8105469,
|
|
||||||
"special": false,
|
|
||||||
"text": "uri"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.6396484,
|
"logprob": -2.3300781,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3057,
|
"id": 3057,
|
||||||
"logprob": -1.0546875,
|
"logprob": -1.8740234,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -0.6513672,
|
"logprob": -1.2646484,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": 29918,
|
|
||||||
"logprob": -0.056365967,
|
|
||||||
"special": false,
|
|
||||||
"text": "_"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5338,
|
|
||||||
"logprob": -0.016082764,
|
|
||||||
"special": false,
|
|
||||||
"text": "uri"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.87841797,
|
"logprob": -1.7158203,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3057,
|
"id": 3057,
|
||||||
"logprob": -0.7548828,
|
"logprob": -1.4667969,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -0.29711914,
|
"logprob": -0.15344238,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.81591797,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3057,
|
||||||
|
"logprob": -0.22973633,
|
||||||
|
"special": false,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2009,
|
||||||
|
"logprob": -0.007045746,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.021957397,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "_uri\nTest request_uri\nTest request"
|
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -100,80 +101,81 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4321,
|
"id": 4321,
|
||||||
"logprob": -9.6015625,
|
"logprob": -9.84375,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -9.6640625,
|
"logprob": -9.59375,
|
||||||
"text": "request"
|
"text": "request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
|
||||||
"id": 29918,
|
|
||||||
"logprob": -2.3828125,
|
|
||||||
"special": false,
|
|
||||||
"text": "_"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5338,
|
|
||||||
"logprob": -2.828125,
|
|
||||||
"special": false,
|
|
||||||
"text": "uri"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.6386719,
|
"logprob": -2.3378906,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3057,
|
"id": 3057,
|
||||||
"logprob": -1.0527344,
|
"logprob": -1.8779297,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -0.6542969,
|
"logprob": -1.2636719,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": 29918,
|
|
||||||
"logprob": -0.055877686,
|
|
||||||
"special": false,
|
|
||||||
"text": "_"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5338,
|
|
||||||
"logprob": -0.016021729,
|
|
||||||
"special": false,
|
|
||||||
"text": "uri"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.8769531,
|
"logprob": -1.6992188,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3057,
|
"id": 3057,
|
||||||
"logprob": -0.7583008,
|
"logprob": -1.4589844,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -0.29833984,
|
"logprob": -0.15344238,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.79052734,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3057,
|
||||||
|
"logprob": -0.22937012,
|
||||||
|
"special": false,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2009,
|
||||||
|
"logprob": -0.007041931,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.022140503,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "_uri\nTest request_uri\nTest request"
|
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -188,80 +190,81 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4321,
|
"id": 4321,
|
||||||
"logprob": -9.6015625,
|
"logprob": -9.84375,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -9.671875,
|
"logprob": -9.609375,
|
||||||
"text": "request"
|
"text": "request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
|
||||||
"id": 29918,
|
|
||||||
"logprob": -2.3847656,
|
|
||||||
"special": false,
|
|
||||||
"text": "_"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5338,
|
|
||||||
"logprob": -2.8144531,
|
|
||||||
"special": false,
|
|
||||||
"text": "uri"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.6396484,
|
"logprob": -2.3261719,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3057,
|
"id": 3057,
|
||||||
"logprob": -1.0527344,
|
"logprob": -1.8730469,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -0.65478516,
|
"logprob": -1.2587891,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": 29918,
|
|
||||||
"logprob": -0.056243896,
|
|
||||||
"special": false,
|
|
||||||
"text": "_"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5338,
|
|
||||||
"logprob": -0.016143799,
|
|
||||||
"special": false,
|
|
||||||
"text": "uri"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.8808594,
|
"logprob": -1.6894531,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3057,
|
"id": 3057,
|
||||||
"logprob": -0.75341797,
|
"logprob": -1.46875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -0.2956543,
|
"logprob": -0.1541748,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.80322266,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3057,
|
||||||
|
"logprob": -0.22912598,
|
||||||
|
"special": false,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2009,
|
||||||
|
"logprob": -0.0070495605,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.021606445,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "_uri\nTest request_uri\nTest request"
|
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -276,79 +279,80 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4321,
|
"id": 4321,
|
||||||
"logprob": -9.6015625,
|
"logprob": -9.84375,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -9.6640625,
|
"logprob": -9.6015625,
|
||||||
"text": "request"
|
"text": "request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"seed": null,
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
|
||||||
"id": 29918,
|
|
||||||
"logprob": -2.3769531,
|
|
||||||
"special": false,
|
|
||||||
"text": "_"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5338,
|
|
||||||
"logprob": -2.8183594,
|
|
||||||
"special": false,
|
|
||||||
"text": "uri"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -1.6396484,
|
"logprob": -2.3320312,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3057,
|
"id": 3057,
|
||||||
"logprob": -1.0546875,
|
"logprob": -1.875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -0.65478516,
|
"logprob": -1.2646484,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": 29918,
|
|
||||||
"logprob": -0.05557251,
|
|
||||||
"special": false,
|
|
||||||
"text": "_"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 5338,
|
|
||||||
"logprob": -0.01612854,
|
|
||||||
"special": false,
|
|
||||||
"text": "uri"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 13,
|
"id": 13,
|
||||||
"logprob": -0.8730469,
|
"logprob": -1.6884766,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": "\n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3057,
|
"id": 3057,
|
||||||
"logprob": -0.7519531,
|
"logprob": -1.4589844,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2009,
|
"id": 2009,
|
||||||
"logprob": -0.29785156,
|
"logprob": -0.15185547,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.79833984,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3057,
|
||||||
|
"logprob": -0.22827148,
|
||||||
|
"special": false,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2009,
|
||||||
|
"logprob": -0.006996155,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.021560669,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "_uri\nTest request_uri\nTest request"
|
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,193 +1,194 @@
|
||||||
{
|
{
|
||||||
"generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L",
|
|
||||||
"details": {
|
"details": {
|
||||||
"best_of_sequences": null,
|
"best_of_sequences": null,
|
||||||
"finish_reason": "length",
|
"finish_reason": "length",
|
||||||
"generated_tokens": 20,
|
"generated_tokens": 20,
|
||||||
"seed": null,
|
|
||||||
"prefill": [
|
"prefill": [
|
||||||
{
|
{
|
||||||
"id": 589,
|
"id": 589,
|
||||||
"text": "def",
|
"logprob": null,
|
||||||
"logprob": null
|
"text": "def"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"text": " ge",
|
"logprob": -8.5859375,
|
||||||
"logprob": -9.0234375
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"text": "ometric",
|
"logprob": -7.5859375,
|
||||||
"logprob": -9.0859375
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"text": "_",
|
"logprob": -0.2668457,
|
||||||
"logprob": -0.25878906
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"text": "mean",
|
"logprob": -1.6416016,
|
||||||
"logprob": -2.2109375
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"text": "(",
|
"logprob": -0.22705078,
|
||||||
"logprob": -0.30371094
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"text": "L",
|
"logprob": -5.2304688,
|
||||||
"logprob": -5.6054688
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"text": ":",
|
"logprob": -3.0976562,
|
||||||
"logprob": -3.0722656
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"text": " List",
|
"logprob": -1.1044922,
|
||||||
"logprob": -0.6879883
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"text": "[",
|
"logprob": -0.14294434,
|
||||||
"logprob": -0.38500977
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"text": "float",
|
"logprob": -0.32299805,
|
||||||
"logprob": -0.984375
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"text": "]):",
|
"logprob": -2.8164062,
|
||||||
"logprob": -2.5351562
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
"seed": null,
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"text": "\n ",
|
"logprob": -0.1282959,
|
||||||
"logprob": -1.1738281,
|
"special": false,
|
||||||
"special": false
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 1524,
|
||||||
"text": " return",
|
"logprob": -0.97998047,
|
||||||
"logprob": -0.95947266,
|
"special": false,
|
||||||
"special": false
|
"text": " \"\"\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3632,
|
"id": 284,
|
||||||
"text": " sum",
|
"logprob": -0.7006836,
|
||||||
"logprob": -1.4199219,
|
"special": false,
|
||||||
"special": false
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 14883,
|
||||||
"text": "(",
|
"logprob": -2.1933594,
|
||||||
"logprob": -0.085876465,
|
"special": false,
|
||||||
"special": false
|
"text": " Calculate"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 322,
|
||||||
"text": "L",
|
"logprob": -0.2697754,
|
||||||
"logprob": -0.09875488,
|
"special": false,
|
||||||
"special": false
|
"text": " the"
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 27,
|
|
||||||
"text": ")",
|
|
||||||
"logprob": -0.30517578,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 517,
|
|
||||||
"text": " /",
|
|
||||||
"logprob": -0.42089844,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2069,
|
|
||||||
"text": " len",
|
|
||||||
"logprob": -0.042053223,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 26,
|
|
||||||
"text": "(",
|
|
||||||
"logprob": -0.0011806488,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 62,
|
|
||||||
"text": "L",
|
|
||||||
"logprob": -0.0005259514,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 27,
|
|
||||||
"text": ")",
|
|
||||||
"logprob": -0.0017633438,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 478,
|
|
||||||
"text": "\n\n",
|
|
||||||
"logprob": -0.69189453,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 203,
|
|
||||||
"text": "\n",
|
|
||||||
"logprob": -0.041870117,
|
|
||||||
"special": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 589,
|
|
||||||
"text": "def",
|
|
||||||
"logprob": -0.27856445,
|
|
||||||
"special": false
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"text": " ge",
|
"logprob": -0.0836792,
|
||||||
"logprob": -1.7255859,
|
"special": false,
|
||||||
"special": false
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"text": "ometric",
|
"logprob": -0.018737793,
|
||||||
"logprob": -0.011291504,
|
"special": false,
|
||||||
"special": false
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 5651,
|
||||||
"text": "_",
|
"logprob": -0.028640747,
|
||||||
"logprob": -0.008430481,
|
"special": false,
|
||||||
"special": false
|
"text": " mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 432,
|
||||||
"text": "mean",
|
"logprob": -0.29467773,
|
||||||
"logprob": -0.025787354,
|
"special": false,
|
||||||
"special": false
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 312,
|
||||||
"text": "(",
|
"logprob": -0.31518555,
|
||||||
"logprob": -0.073913574,
|
"special": false,
|
||||||
"special": false
|
"text": " a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 1149,
|
||||||
"text": "L",
|
"logprob": -0.20605469,
|
||||||
"logprob": -0.09967041,
|
"special": false,
|
||||||
"special": false
|
"text": " list"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 432,
|
||||||
|
"logprob": -0.23254395,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 7515,
|
||||||
|
"logprob": -0.4489746,
|
||||||
|
"special": false,
|
||||||
|
"text": " numbers"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 32,
|
||||||
|
"logprob": -0.6044922,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 446,
|
||||||
|
"logprob": -0.63964844,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n\n "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 499,
|
||||||
|
"logprob": -1.1953125,
|
||||||
|
"special": false,
|
||||||
|
"text": " :"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 753,
|
||||||
|
"logprob": -0.03515625,
|
||||||
|
"special": false,
|
||||||
|
"text": "param"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 498,
|
||||||
|
"logprob": -0.06311035,
|
||||||
|
"special": false,
|
||||||
|
"text": " L"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 44,
|
||||||
|
"logprob": -0.003414154,
|
||||||
|
"special": false,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1682,
|
||||||
|
"logprob": -1.3310547,
|
||||||
|
"special": false,
|
||||||
|
"text": " List"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
}
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List"
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,57 +11,57 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -9.0234375,
|
"logprob": -8.5859375,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -9.0859375,
|
"logprob": -7.5898438,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.25830078,
|
"logprob": -0.26586914,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -2.1875,
|
"logprob": -1.6347656,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.30004883,
|
"logprob": -0.22705078,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -5.6171875,
|
"logprob": -5.2382812,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.078125,
|
"logprob": -3.0996094,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.68066406,
|
"logprob": -1.1025391,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.38745117,
|
"logprob": -0.14294434,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.9453125,
|
"logprob": -0.32226562,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -2.5371094,
|
"logprob": -2.8164062,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -69,19 +69,19 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": -0.051635742,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 442,
|
||||||
"logprob": 0.0,
|
"logprob": -1.3134766,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " return"
|
"text": " return"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 11665,
|
"id": 11665,
|
||||||
"logprob": -1.2236328,
|
"logprob": -0.10021973,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " reduce"
|
"text": " reduce"
|
||||||
},
|
},
|
||||||
|
@ -129,7 +129,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 319,
|
"id": 319,
|
||||||
"logprob": 0.0,
|
"logprob": -0.42871094,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " *"
|
"text": " *"
|
||||||
},
|
},
|
||||||
|
@ -158,36 +158,37 @@
|
||||||
"text": ")"
|
"text": ")"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 203,
|
"id": 1115,
|
||||||
"logprob": -0.12695312,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 203,
|
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n"
|
"text": " **"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 589,
|
"id": 308,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "def"
|
"text": " ("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 35,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " ge"
|
"text": "1"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 32,
|
||||||
|
"logprob": -0.31323242,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 34,
|
||||||
"logprob": 0.0,
|
"logprob": 0.0,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "ometric"
|
"text": "0"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric"
|
"generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0"
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,57 +12,57 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -9.0234375,
|
"logprob": -8.5859375,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -9.0859375,
|
"logprob": -7.5820312,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.25927734,
|
"logprob": -0.26708984,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -2.25,
|
"logprob": -1.6386719,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.30126953,
|
"logprob": -0.22717285,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -5.7539062,
|
"logprob": -5.234375,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.0878906,
|
"logprob": -3.1015625,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.6845703,
|
"logprob": -1.1083984,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.3918457,
|
"logprob": -0.14294434,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.8798828,
|
"logprob": -0.32592773,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -2.4980469,
|
"logprob": -2.8164062,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -70,67 +70,68 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": -1.1533203,
|
"logprob": -0.12817383,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 1524,
|
||||||
"logprob": -0.91796875,
|
"logprob": -0.9863281,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " return"
|
"text": " \"\"\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3632,
|
"id": 284,
|
||||||
"logprob": -1.3291016,
|
"logprob": -0.7011719,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " sum"
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 14883,
|
||||||
"logprob": -0.08062744,
|
"logprob": -2.2050781,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " Calculate"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 322,
|
||||||
"logprob": -0.097717285,
|
"logprob": -0.2668457,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " the"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27,
|
"id": 3226,
|
||||||
"logprob": -0.29003906,
|
"logprob": -0.08465576,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": ")"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 517,
|
"id": 21017,
|
||||||
"logprob": -0.34958984,
|
"logprob": -0.019012451,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " /"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2069,
|
"id": 5651,
|
||||||
"logprob": -0.03829956,
|
"logprob": -0.028625488,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " len"
|
"text": " mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 432,
|
||||||
"logprob": -0.0011987686,
|
"logprob": -0.29418945,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 312,
|
||||||
"logprob": -0.00050878525,
|
"logprob": -0.3161621,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " a"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n return sum(L) / len(L"
|
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -145,57 +146,57 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -9.0234375,
|
"logprob": -8.5859375,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -9.0859375,
|
"logprob": -7.59375,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.25878906,
|
"logprob": -0.26953125,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -2.2109375,
|
"logprob": -1.640625,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.30371094,
|
"logprob": -0.22705078,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -5.6054688,
|
"logprob": -5.234375,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.0722656,
|
"logprob": -3.1132812,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.6879883,
|
"logprob": -1.1123047,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.38500977,
|
"logprob": -0.14294434,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.984375,
|
"logprob": -0.32299805,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -2.5351562,
|
"logprob": -2.8164062,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -203,67 +204,68 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": -1.1738281,
|
"logprob": -0.12854004,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 1524,
|
||||||
"logprob": -0.9584961,
|
"logprob": -0.9897461,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " return"
|
"text": " \"\"\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3632,
|
"id": 284,
|
||||||
"logprob": -1.4169922,
|
"logprob": -0.69970703,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " sum"
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 14883,
|
||||||
"logprob": -0.085876465,
|
"logprob": -2.2050781,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " Calculate"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 322,
|
||||||
"logprob": -0.0982666,
|
"logprob": -0.2668457,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " the"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27,
|
"id": 3226,
|
||||||
"logprob": -0.3022461,
|
"logprob": -0.08496094,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": ")"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 517,
|
"id": 21017,
|
||||||
"logprob": -0.40504883,
|
"logprob": -0.019012451,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " /"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2069,
|
"id": 5651,
|
||||||
"logprob": -0.041656494,
|
"logprob": -0.029037476,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " len"
|
"text": " mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 432,
|
||||||
"logprob": -0.0011844635,
|
"logprob": -0.2939453,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 312,
|
||||||
"logprob": -0.0005264282,
|
"logprob": -0.31591797,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " a"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n return sum(L) / len(L"
|
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -278,57 +280,57 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -9.0234375,
|
"logprob": -8.5859375,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -9.0859375,
|
"logprob": -7.5859375,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.25927734,
|
"logprob": -0.26586914,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -2.25,
|
"logprob": -1.6347656,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.30126953,
|
"logprob": -0.22766113,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -5.7539062,
|
"logprob": -5.2265625,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.0878906,
|
"logprob": -3.0976562,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.6845703,
|
"logprob": -1.1025391,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.3918457,
|
"logprob": -0.1427002,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.8798828,
|
"logprob": -0.32592773,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -2.4980469,
|
"logprob": -2.8164062,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -336,67 +338,68 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": -1.1533203,
|
"logprob": -0.13012695,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 1524,
|
||||||
"logprob": -0.9165039,
|
"logprob": -0.98046875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " return"
|
"text": " \"\"\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3632,
|
"id": 284,
|
||||||
"logprob": -1.328125,
|
"logprob": -0.69921875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " sum"
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 14883,
|
||||||
"logprob": -0.07946777,
|
"logprob": -2.1992188,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " Calculate"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 322,
|
||||||
"logprob": -0.09820557,
|
"logprob": -0.2668457,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " the"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27,
|
"id": 3226,
|
||||||
"logprob": -0.28930664,
|
"logprob": -0.083496094,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": ")"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 517,
|
"id": 21017,
|
||||||
"logprob": -0.34592773,
|
"logprob": -0.01902771,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " /"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2069,
|
"id": 5651,
|
||||||
"logprob": -0.038330078,
|
"logprob": -0.029006958,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " len"
|
"text": " mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 432,
|
||||||
"logprob": -0.0011940002,
|
"logprob": -0.29248047,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 312,
|
||||||
"logprob": -0.00050878525,
|
"logprob": -0.3161621,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " a"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n return sum(L) / len(L"
|
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"details": {
|
"details": {
|
||||||
|
@ -411,57 +414,57 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3226,
|
"id": 3226,
|
||||||
"logprob": -9.0234375,
|
"logprob": -8.5859375,
|
||||||
"text": " ge"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 21017,
|
"id": 21017,
|
||||||
"logprob": -9.0859375,
|
"logprob": -7.5859375,
|
||||||
"text": "ometric"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 81,
|
"id": 81,
|
||||||
"logprob": -0.25927734,
|
"logprob": -0.26904297,
|
||||||
"text": "_"
|
"text": "_"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 6009,
|
"id": 6009,
|
||||||
"logprob": -2.25,
|
"logprob": -1.6386719,
|
||||||
"text": "mean"
|
"text": "mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 26,
|
||||||
"logprob": -0.30126953,
|
"logprob": -0.22705078,
|
||||||
"text": "("
|
"text": "("
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 62,
|
||||||
"logprob": -5.7539062,
|
"logprob": -5.234375,
|
||||||
"text": "L"
|
"text": "L"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 44,
|
"id": 44,
|
||||||
"logprob": -3.0878906,
|
"logprob": -3.1132812,
|
||||||
"text": ":"
|
"text": ":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1682,
|
"id": 1682,
|
||||||
"logprob": -0.6845703,
|
"logprob": -1.1074219,
|
||||||
"text": " List"
|
"text": " List"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 77,
|
"id": 77,
|
||||||
"logprob": -0.3918457,
|
"logprob": -0.14477539,
|
||||||
"text": "["
|
"text": "["
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1808,
|
"id": 1808,
|
||||||
"logprob": -0.8798828,
|
"logprob": -0.3256836,
|
||||||
"text": "float"
|
"text": "float"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 10794,
|
"id": 10794,
|
||||||
"logprob": -2.4980469,
|
"logprob": -2.8027344,
|
||||||
"text": "]):"
|
"text": "]):"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -469,66 +472,67 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 284,
|
"id": 284,
|
||||||
"logprob": -1.1533203,
|
"logprob": -0.12915039,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n "
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 442,
|
"id": 1524,
|
||||||
"logprob": -0.91259766,
|
"logprob": -0.98535156,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " return"
|
"text": " \"\"\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3632,
|
"id": 284,
|
||||||
"logprob": -1.3251953,
|
"logprob": -0.69921875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " sum"
|
"text": "\n "
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 14883,
|
||||||
"logprob": -0.08062744,
|
"logprob": -2.2011719,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " Calculate"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 322,
|
||||||
"logprob": -0.09906006,
|
"logprob": -0.26708984,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " the"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 27,
|
"id": 3226,
|
||||||
"logprob": -0.28979492,
|
"logprob": -0.08502197,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": ")"
|
"text": " ge"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 517,
|
"id": 21017,
|
||||||
"logprob": -0.35958984,
|
"logprob": -0.019012451,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " /"
|
"text": "ometric"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2069,
|
"id": 5651,
|
||||||
"logprob": -0.038604736,
|
"logprob": -0.028625488,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " len"
|
"text": " mean"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 26,
|
"id": 432,
|
||||||
"logprob": -0.0011901855,
|
"logprob": -0.29589844,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "("
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 62,
|
"id": 312,
|
||||||
"logprob": -0.0005078316,
|
"logprob": -0.31591797,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "L"
|
"text": " a"
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "\n return sum(L) / len(L"
|
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
|
@ -85,7 +85,7 @@ __global__ void q4_matmul_kernel
|
||||||
if constexpr (use_half2)
|
if constexpr (use_half2)
|
||||||
{
|
{
|
||||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||||
|
|
||||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||||
|
@ -93,7 +93,7 @@ __global__ void q4_matmul_kernel
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
half w_scale = w_scales_.item(group, w_column);
|
half w_scale = w_scales_.item(group, w_column);
|
||||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||||
|
|
||||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||||
|
@ -110,7 +110,7 @@ __global__ void q4_matmul_kernel
|
||||||
{
|
{
|
||||||
int group = k / groupsize;
|
int group = k / groupsize;
|
||||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||||
|
|
||||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||||
|
@ -119,7 +119,7 @@ __global__ void q4_matmul_kernel
|
||||||
{
|
{
|
||||||
int group = k / groupsize;
|
int group = k / groupsize;
|
||||||
half w_scale = w_scales_.item(group, w_column);
|
half w_scale = w_scales_.item(group, w_column);
|
||||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
|
||||||
|
|
||||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||||
|
|
|
@ -189,7 +189,7 @@ __global__ void reconstruct_kernel
|
||||||
int group = row / groupsize;
|
int group = row / groupsize;
|
||||||
|
|
||||||
half w_scale = w_scales_.item(group, column);
|
half w_scale = w_scales_.item(group, column);
|
||||||
uint32_t w_zero = w_zeros_.item(group, column) + 1;
|
uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F;
|
||||||
|
|
||||||
uint32_t w_read = w_.item_uint32_t(row, column);
|
uint32_t w_read = w_.item_uint32_t(row, column);
|
||||||
half* out_ptr = out_.item_ptr(row, column);
|
half* out_ptr = out_.item_ptr(row, column);
|
||||||
|
|
|
@ -152,10 +152,10 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||||
half2 y1y16[4][2];
|
half2 y1y16[4][2];
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_h2(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||||
|
|
||||||
// __syncthreads();
|
// __syncthreads();
|
||||||
|
|
||||||
|
@ -174,10 +174,10 @@ __global__ void gemm_half_q_half_gptq_kernel
|
||||||
nextgroup += groupsize;
|
nextgroup += groupsize;
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_h2(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
|
|
@ -237,10 +237,10 @@ __global__ void reconstruct_gptq_kernel
|
||||||
half2 y1y16[4][2];
|
half2 y1y16[4][2];
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_h2(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
@ -255,10 +255,10 @@ __global__ void reconstruct_gptq_kernel
|
||||||
nextgroup += groupsize;
|
nextgroup += groupsize;
|
||||||
b_gptq_qzeros_.item4(zeros, group, n);
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
b_gptq_scales_.item4_h2(scales, group, n);
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
|
||||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
|
||||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
|
||||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int p = 0; p < 4; p++)
|
for (int p = 0; p < 4; p++)
|
||||||
|
|
|
@ -69,9 +69,17 @@ def _load_multi_mqa_gptq(
|
||||||
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||||
qzeros = qzeros.to(device=weights.device)
|
qzeros = qzeros.to(device=weights.device)
|
||||||
|
|
||||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
bits, groupsize, _, quant_method, = weights._get_gptq_params()
|
||||||
g_idx = g_idx.to(device=weights.device)
|
if quant_method == "gptq":
|
||||||
bits, groupsize, _ = weights._get_gptq_params()
|
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
||||||
|
g_idx = g_idx.to(device=weights.device)
|
||||||
|
elif quant_method == "awq":
|
||||||
|
g_idx = None
|
||||||
|
from text_generation_server.utils.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
|
||||||
from text_generation_server.utils.layers import HAS_EXLLAMA
|
from text_generation_server.utils.layers import HAS_EXLLAMA
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
import torch
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
|
||||||
|
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def pack(imatrix: torch.Tensor, direction: str = "column"):
|
||||||
|
"""
|
||||||
|
Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
|
||||||
|
Args:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
direction (str): direction of packing, either "column" or "row"
|
||||||
|
Returns:
|
||||||
|
qmatrix (torch.Tensor): packed matrix of integers
|
||||||
|
"""
|
||||||
|
shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
|
||||||
|
|
||||||
|
imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow
|
||||||
|
|
||||||
|
if direction == "column":
|
||||||
|
imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
|
||||||
|
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
|
||||||
|
|
||||||
|
elif direction == "row":
|
||||||
|
imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
|
||||||
|
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
|
||||||
|
|
||||||
|
qmatrix = qmatrix.to(torch.int32)
|
||||||
|
|
||||||
|
return qmatrix
|
||||||
|
|
||||||
|
|
||||||
|
def unpack(qmatrix: torch.Tensor, direction: str = "column"):
|
||||||
|
"""
|
||||||
|
Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
|
||||||
|
Args:
|
||||||
|
qmatrix (torch.Tensor): matrix of packed integers
|
||||||
|
direction (str): direction of unpacking, either "column" or "row"
|
||||||
|
Returns:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
"""
|
||||||
|
shifts = torch.arange(0, 32, 4, device=qmatrix.device)
|
||||||
|
|
||||||
|
if direction == "column":
|
||||||
|
imatrix = torch.bitwise_right_shift(
|
||||||
|
qmatrix[:, :, None], shifts[None, None, :]
|
||||||
|
).view(qmatrix.shape[0], -1)
|
||||||
|
|
||||||
|
elif direction == "row":
|
||||||
|
imatrix = torch.bitwise_right_shift(
|
||||||
|
qmatrix[:, None, :], shifts[None, :, None]
|
||||||
|
).view(-1, qmatrix.shape[-1])
|
||||||
|
|
||||||
|
imatrix = imatrix.to(torch.int8) & 0x0F # eventually correct overflow
|
||||||
|
|
||||||
|
return imatrix
|
||||||
|
|
||||||
|
|
||||||
|
def apply_order(
|
||||||
|
imatrix: torch.Tensor,
|
||||||
|
direction: str = "column",
|
||||||
|
order: List[int] = AWQ_PACK_ORDER,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Applies the order to a 4-bit integer matrix.
|
||||||
|
Args:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
direction (str): direction of applying order, either "column" or "row"
|
||||||
|
order (List[int]): order to apply, default is AWQ_PACK_ORDER
|
||||||
|
Returns:
|
||||||
|
imatrix (torch.Tensor): matrix of integers
|
||||||
|
"""
|
||||||
|
if direction == "column":
|
||||||
|
imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
|
||||||
|
elif direction == "row":
|
||||||
|
imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)
|
||||||
|
|
||||||
|
return imatrix
|
||||||
|
|
||||||
|
|
||||||
|
def fast_awq_to_gptq(qweight, qzeros):
|
||||||
|
# awq uses column packing for both weights and zeros
|
||||||
|
izeros = unpack(qzeros, direction="column")
|
||||||
|
iweights = unpack(qweight, direction="column")
|
||||||
|
|
||||||
|
# Reverse the order of the iweight and izeros tensors
|
||||||
|
izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
|
||||||
|
iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
|
||||||
|
# Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
|
||||||
|
izeros = izeros - 1
|
||||||
|
# exllama uses row packing for weights and column packing for zeros
|
||||||
|
qzeros = pack(izeros, direction="column")
|
||||||
|
qweight = pack(iweights, direction="row")
|
||||||
|
|
||||||
|
return qweight, qzeros
|
|
@ -182,7 +182,7 @@ try:
|
||||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||||
|
|
||||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||||
zeros = zeros + 1
|
zeros = (zeros + 1) & maxq # eventually avoid overflow
|
||||||
|
|
||||||
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||||
|
|
|
@ -349,6 +349,13 @@ def get_linear(weight, bias, quantize):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"The passed weight is not `awq` compatible, loader needs to be updated."
|
f"The passed weight is not `awq` compatible, loader needs to be updated."
|
||||||
)
|
)
|
||||||
|
if IS_ROCM_SYSTEM:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
|
||||||
|
"to use Exllama/GPTQ kernels for AWQ inference."
|
||||||
|
)
|
||||||
|
if not HAS_AWQ:
|
||||||
|
raise NotImplementedError("You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly")
|
||||||
linear = WQLinear(
|
linear = WQLinear(
|
||||||
w_bit=bits,
|
w_bit=bits,
|
||||||
group_size=groupsize,
|
group_size=groupsize,
|
||||||
|
|
|
@ -46,7 +46,6 @@ class Weights:
|
||||||
return self._handles[filename]
|
return self._handles[filename]
|
||||||
|
|
||||||
def get_filename(self, tensor_name: str) -> (str, str):
|
def get_filename(self, tensor_name: str) -> (str, str):
|
||||||
|
|
||||||
names = [tensor_name]
|
names = [tensor_name]
|
||||||
if self.prefix is not None:
|
if self.prefix is not None:
|
||||||
prefixed = f"{self.prefix}.{tensor_name}"
|
prefixed = f"{self.prefix}.{tensor_name}"
|
||||||
|
@ -154,15 +153,30 @@ class Weights:
|
||||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
bits, groupsize, _, quant_method = self._get_gptq_params()
|
||||||
|
|
||||||
qzeros = self._get_qweight(f"{prefix}.qzeros")
|
qzeros = self._get_qweight(f"{prefix}.qzeros")
|
||||||
scales = self._get_qweight(f"{prefix}.scales")
|
scales = self._get_qweight(f"{prefix}.scales")
|
||||||
scales = scales.to(dtype=self.dtype)
|
scales = scales.to(dtype=self.dtype)
|
||||||
if quantize == "gptq":
|
|
||||||
|
if quantize == "gptq" and quant_method == "gptq":
|
||||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||||
|
elif quantize == "gptq" and quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
|
||||||
|
// groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
else:
|
else:
|
||||||
g_idx = None
|
g_idx = None
|
||||||
|
|
||||||
bits, groupsize, _ = self._get_gptq_params()
|
|
||||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
|
||||||
else:
|
else:
|
||||||
slice_ = self._get_slice(f"{prefix}.weight")
|
slice_ = self._get_slice(f"{prefix}.weight")
|
||||||
|
@ -204,20 +218,40 @@ class Weights:
|
||||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
if quantize == "gptq":
|
bits, groupsize, desc_act, quant_method = self._get_gptq_params()
|
||||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
|
||||||
for w2 in w[1:]:
|
|
||||||
torch.testing.assert_close(w2, w[0])
|
|
||||||
g_idx = w[0]
|
|
||||||
else:
|
|
||||||
g_idx = None
|
|
||||||
|
|
||||||
bits, groupsize, desc_act = self._get_gptq_params()
|
|
||||||
from text_generation_server.utils.layers import HAS_EXLLAMA
|
from text_generation_server.utils.layers import HAS_EXLLAMA
|
||||||
|
|
||||||
use_exllama = (
|
use_exllama = (
|
||||||
bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
|
bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if quantize == "gptq" and quant_method == "gptq":
|
||||||
|
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||||
|
for w2 in w[1:]:
|
||||||
|
torch.testing.assert_close(w2, w[0])
|
||||||
|
g_idx = w[0]
|
||||||
|
elif quantize == "gptq" and quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
if use_exllama:
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // bits), device=qweight.device
|
||||||
|
)
|
||||||
|
// groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
else:
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||||
else:
|
else:
|
||||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||||
|
@ -243,7 +277,7 @@ class Weights:
|
||||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||||
if quantize == "gptq":
|
if quantize == "gptq":
|
||||||
use_exllama = True
|
use_exllama = True
|
||||||
bits, groupsize, desc_act = self._get_gptq_params()
|
bits, groupsize, desc_act, quant_method = self._get_gptq_params()
|
||||||
|
|
||||||
if bits != 4:
|
if bits != 4:
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
@ -252,8 +286,19 @@ class Weights:
|
||||||
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if quant_method == "gptq":
|
||||||
|
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
|
elif quant_method == "awq":
|
||||||
|
g_idx = None
|
||||||
|
|
||||||
if self.process_group.size() > 1:
|
if self.process_group.size() > 1:
|
||||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
|
||||||
if g_idx is not None:
|
if g_idx is not None:
|
||||||
if (
|
if (
|
||||||
not torch.equal(
|
not torch.equal(
|
||||||
|
@ -269,13 +314,6 @@ class Weights:
|
||||||
# it would require to reorder input activations that are split unto several GPUs
|
# it would require to reorder input activations that are split unto several GPUs
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
|
||||||
try:
|
|
||||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
|
||||||
)
|
|
||||||
|
|
||||||
from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA
|
from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA
|
||||||
|
|
||||||
if use_exllama:
|
if use_exllama:
|
||||||
|
@ -289,8 +327,6 @@ class Weights:
|
||||||
else:
|
else:
|
||||||
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||||
|
|
||||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
|
||||||
|
|
||||||
if use_exllama and groupsize != -1:
|
if use_exllama and groupsize != -1:
|
||||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
|
@ -298,12 +334,31 @@ class Weights:
|
||||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||||
scales = self.get_tensor(f"{prefix}.scales")
|
scales = self.get_tensor(f"{prefix}.scales")
|
||||||
|
|
||||||
if use_exllama:
|
if use_exllama and g_idx is not None:
|
||||||
g_idx = g_idx - g_idx[0]
|
g_idx = g_idx - g_idx[0]
|
||||||
|
|
||||||
|
if quant_method == "awq":
|
||||||
|
log_once(
|
||||||
|
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||||
|
)
|
||||||
|
from text_generation_server.utils.awq.conversion_utils import (
|
||||||
|
fast_awq_to_gptq,
|
||||||
|
)
|
||||||
|
|
||||||
|
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||||
|
if use_exllama:
|
||||||
|
g_idx = None
|
||||||
|
else:
|
||||||
|
g_idx = (
|
||||||
|
torch.arange(
|
||||||
|
qweight.shape[0] * (32 // bits), device=qweight.device
|
||||||
|
)
|
||||||
|
// groupsize
|
||||||
|
).to(dtype=torch.int32)
|
||||||
|
|
||||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||||
elif quantize == "awq":
|
elif quantize == "awq":
|
||||||
bits, groupsize, _ = self._get_gptq_params()
|
bits, groupsize, _, _ = self._get_gptq_params()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
|
@ -322,20 +377,22 @@ class Weights:
|
||||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
def _get_gptq_params(self) -> Tuple[int, int, int]:
|
def _get_gptq_params(self) -> Tuple[int, int, int, str]:
|
||||||
try:
|
try:
|
||||||
bits = self.get_tensor("gptq_bits").item()
|
bits = self.get_tensor("gptq_bits").item()
|
||||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||||
desc_act = False
|
desc_act = False
|
||||||
|
quant_method = "gptq"
|
||||||
except (SafetensorError, RuntimeError) as e:
|
except (SafetensorError, RuntimeError) as e:
|
||||||
try:
|
try:
|
||||||
bits = self.gptq_bits
|
bits = self.gptq_bits
|
||||||
groupsize = self.gptq_groupsize
|
groupsize = self.gptq_groupsize
|
||||||
desc_act = getattr(self, "gptq_desc_act", False)
|
desc_act = getattr(self, "gptq_desc_act", False)
|
||||||
|
quant_method = getattr(self, "quant_method", "gptq")
|
||||||
except Exception:
|
except Exception:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return bits, groupsize, desc_act
|
return bits, groupsize, desc_act, quant_method
|
||||||
|
|
||||||
def _set_gptq_params(self, model_id, revision):
|
def _set_gptq_params(self, model_id, revision):
|
||||||
filename = "config.json"
|
filename = "config.json"
|
||||||
|
@ -351,6 +408,7 @@ class Weights:
|
||||||
self.gptq_bits = data["quantization_config"]["bits"]
|
self.gptq_bits = data["quantization_config"]["bits"]
|
||||||
self.gptq_groupsize = data["quantization_config"]["group_size"]
|
self.gptq_groupsize = data["quantization_config"]["group_size"]
|
||||||
self.gptq_desc_act = data["quantization_config"]["desc_act"]
|
self.gptq_desc_act = data["quantization_config"]["desc_act"]
|
||||||
|
self.quant_method = data["quantization_config"]["quant_method"]
|
||||||
except Exception:
|
except Exception:
|
||||||
filename = "quantize_config.json"
|
filename = "quantize_config.json"
|
||||||
try:
|
try:
|
||||||
|
@ -365,6 +423,8 @@ class Weights:
|
||||||
self.gptq_bits = data["bits"]
|
self.gptq_bits = data["bits"]
|
||||||
self.gptq_groupsize = data["group_size"]
|
self.gptq_groupsize = data["group_size"]
|
||||||
self.gptq_desc_act = data["desc_act"]
|
self.gptq_desc_act = data["desc_act"]
|
||||||
|
if "version" in data and data["version"] == "GEMM":
|
||||||
|
self.quant_method = "awq"
|
||||||
except Exception:
|
except Exception:
|
||||||
filename = "quant_config.json"
|
filename = "quant_config.json"
|
||||||
try:
|
try:
|
||||||
|
@ -379,5 +439,7 @@ class Weights:
|
||||||
self.gptq_bits = data["w_bit"]
|
self.gptq_bits = data["w_bit"]
|
||||||
self.gptq_groupsize = data["q_group_size"]
|
self.gptq_groupsize = data["q_group_size"]
|
||||||
self.gptq_desc_act = data["desc_act"]
|
self.gptq_desc_act = data["desc_act"]
|
||||||
|
if "version" in data and data["version"] == "GEMM":
|
||||||
|
self.quant_method = "awq"
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue