ROCm AWQ support (#1514)
# What does this PR do?

This PR adds the possibility to run AWQ models with Exllama/GPTQ kernels, specifically for ROCm devices that support Exllama kernels but not AWQ's GEMM.

This is done by:
- un-packing, reordering and re-packing AWQ weights when running with `--quantize gptq` but the model's `quant_method=awq`;
- avoiding overflows when adding 1 to zeros in the exllama and triton kernels.

Ref: https://github.com/casper-hansen/AutoAWQ/pull/313

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@OlivierDehaene OR @Narsil

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
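In practice, an AWQ checkpoint can then be served on a supported ROCm device through the GPTQ path, e.g. `text-generation-launcher --model-id <awq-model> --quantize gptq`; the loader detects `quant_method=awq` from the checkpoint's quantization config and repacks the weights to the Exllama/GPTQ layout on the fly. (The launcher invocation is illustrative; only `--quantize gptq` is prescribed by this PR.)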
This commit is contained in:
parent c5ef81bed5
commit a4e5801684
[Snapshot JSON diff (greedy "Test request" generation): prefill and per-token logprobs regenerated; `generated_text` changes from `"_uri\nTest request_uri\nTest request"` to `"\nTest request\nTest request\nTest request\n"`. Token-by-token values omitted.]
[Snapshot JSON diff (same prompt, all sampling parameters set): logprobs regenerated; `generated_text` changes from `"Test request-for-comment: #2014"` to `"Test request-for-comment: #2017"`. Token-by-token values omitted.]
[Snapshot JSON diff (load test, four concurrent "Test request" generations): logprobs regenerated for all four responses; each `generated_text` changes from `"_uri\nTest request_uri\nTest request"` to `"\nTest request\nTest request\nTest request\n"`. Token-by-token values omitted.]
[Snapshot JSON diff (greedy completion of `def geometric_mean(L: List[float]):`): prefill and per-token logprobs regenerated; `generated_text` changes from `"\n return sum(L) / len(L)\n\n\ndef geometric_mean(L"` to `"\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List"`. Token-by-token values omitted.]
[Snapshot JSON diff (same prompt, all sampling parameters set): logprobs regenerated; `generated_text` changes from `"\n return reduce(lambda x, y: x * y, L)\n\ndef geometric"` to `"\n return reduce(lambda x, y: x * y, L) ** (1.0"`. Token-by-token values omitted.]
[Snapshot JSON diff (load test, four concurrent completions of the same prompt): logprobs regenerated for all four responses; each `generated_text` changes from `"\n return sum(L) / len(L"` to `"\n \"\"\"\n Calculate the geometric mean of a"`. Token-by-token values omitted.]
@@ -85,7 +85,7 @@ __global__ void q4_matmul_kernel
    if constexpr (use_half2)
    {
        half2 w_scale = w_scales_.item_half2half2(group, w_column);
-       uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+       uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;

        if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
        else                     acc = dot_product_8     (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
@@ -93,7 +93,7 @@ __global__ void q4_matmul_kernel
    else
    {
        half w_scale = w_scales_.item(group, w_column);
-       uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+       uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;

        if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
        else                     acc_h = dot_product_8_h     (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
@@ -110,7 +110,7 @@ __global__ void q4_matmul_kernel
    {
        int group = k / groupsize;
        half2 w_scale = w_scales_.item_half2half2(group, w_column);
-       uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+       uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;

        if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
        else                     acc = dot_product_8     (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
@@ -119,7 +119,7 @@ __global__ void q4_matmul_kernel
    {
        int group = k / groupsize;
        half w_scale = w_scales_.item(group, w_column);
-       uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+       uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;

        if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
        else                     acc_h = dot_product_8_h     (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
@@ -189,7 +189,7 @@ __global__ void reconstruct_kernel
    int group = row / groupsize;

    half w_scale = w_scales_.item(group, column);
-   uint32_t w_zero = w_zeros_.item(group, column) + 1;
+   uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F;

    uint32_t w_read = w_.item_uint32_t(row, column);
    half* out_ptr = out_.item_ptr(row, column);
@@ -152,10 +152,10 @@ __global__ void gemm_half_q_half_gptq_kernel
    half2 y1y16[4][2];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4_h2(scales, group, n);
-   dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-   dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-   dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-   dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+   dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
+   dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
+   dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
+   dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);

    // __syncthreads();

@@ -174,10 +174,10 @@ __global__ void gemm_half_q_half_gptq_kernel
        nextgroup += groupsize;
        b_gptq_qzeros_.item4(zeros, group, n);
        b_gptq_scales_.item4_h2(scales, group, n);
-       dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-       dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-       dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-       dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+       dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
+       dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
+       dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
+       dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
    }

    #pragma unroll
@@ -237,10 +237,10 @@ __global__ void reconstruct_gptq_kernel
    half2 y1y16[4][2];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4_h2(scales, group, n);
-   dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-   dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-   dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-   dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+   dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
+   dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
+   dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
+   dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);

    __syncthreads();

@@ -255,10 +255,10 @@ __global__ void reconstruct_gptq_kernel
        nextgroup += groupsize;
        b_gptq_qzeros_.item4(zeros, group, n);
        b_gptq_scales_.item4_h2(scales, group, n);
-       dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
-       dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
-       dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
-       dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
+       dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
+       dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
+       dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
+       dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
    }

    for (int p = 0; p < 4; p++)
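For context on the `& 0x0F` masking added in the hunks above: GPTQ-style packing stores each 4-bit zero-point minus one, and the kernels re-add the `+ 1` at dequantization time. After the on-the-fly AWQ conversion (which subtracts 1 from the AWQ zeros, see `fast_awq_to_gptq` further down), an original zero-point of 0 ends up stored as `0xF`, so the re-added `+ 1` would spill out of the 4-bit range without the mask. A minimal sketch of the idea in plain PyTorch (illustrative values only, not the actual kernel code):

```python
import torch

# 4-bit zero-points as the kernels read them, i.e. already stored as (zero - 1).
# 7 is a typical GPTQ value; 15 (0xF) is what an AWQ zero of 0 becomes after the -1 conversion.
stored = torch.tensor([7, 15], dtype=torch.int32)

unmasked = stored + 1           # tensor([ 8, 16]) -> 16 no longer fits in 4 bits
masked = (stored + 1) & 0x0F    # tensor([ 8,  0]) -> wraps back into the valid 4-bit range
```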
@@ -69,9 +69,17 @@ def _load_multi_mqa_gptq(
    qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
    qzeros = qzeros.to(device=weights.device)

-   g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
-   g_idx = g_idx.to(device=weights.device)
-   bits, groupsize, _ = weights._get_gptq_params()
+   bits, groupsize, _, quant_method, = weights._get_gptq_params()
+   if quant_method == "gptq":
+       g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
+       g_idx = g_idx.to(device=weights.device)
+   elif quant_method == "awq":
+       g_idx = None
+       from text_generation_server.utils.awq.conversion_utils import (
+           fast_awq_to_gptq,
+       )
+
+       qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)

    from text_generation_server.utils.layers import HAS_EXLLAMA
@@ -0,0 +1,97 @@
import torch
from typing import List


AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def pack(imatrix: torch.Tensor, direction: str = "column"):
    """
    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of packing, either "column" or "row"
    Returns:
        qmatrix (torch.Tensor): packed matrix of integers
    """
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)

    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow

    if direction == "column":
        imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)

    elif direction == "row":
        imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)

    qmatrix = qmatrix.to(torch.int32)

    return qmatrix


def unpack(qmatrix: torch.Tensor, direction: str = "column"):
    """
    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
    Args:
        qmatrix (torch.Tensor): matrix of packed integers
        direction (str): direction of unpacking, either "column" or "row"
    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    shifts = torch.arange(0, 32, 4, device=qmatrix.device)

    if direction == "column":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, :, None], shifts[None, None, :]
        ).view(qmatrix.shape[0], -1)

    elif direction == "row":
        imatrix = torch.bitwise_right_shift(
            qmatrix[:, None, :], shifts[None, :, None]
        ).view(-1, qmatrix.shape[-1])

    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow

    return imatrix


def apply_order(
    imatrix: torch.Tensor,
    direction: str = "column",
    order: List[int] = AWQ_PACK_ORDER,
):
    """
    Applies the order to a 4-bit integer matrix.
    Args:
        imatrix (torch.Tensor): matrix of integers
        direction (str): direction of applying order, either "column" or "row"
        order (List[int]): order to apply, default is AWQ_PACK_ORDER
    Returns:
        imatrix (torch.Tensor): matrix of integers
    """
    if direction == "column":
        imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
    elif direction == "row":
        imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)

    return imatrix


def fast_awq_to_gptq(qweight, qzeros):
    # awq uses column packing for both weights and zeros
    izeros = unpack(qzeros, direction="column")
    iweights = unpack(qweight, direction="column")

    # Reverse the order of the iweight and izeros tensors
    izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
    izeros = izeros - 1
    # exllama uses row packing for weights and column packing for zeros
    qzeros = pack(izeros, direction="column")
    qweight = pack(iweights, direction="row")

    return qweight, qzeros
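A rough usage sketch of these helpers, assuming the new module is importable as `text_generation_server.utils.awq.conversion_utils` (the path used by the imports elsewhere in this PR); the shapes are hypothetical and only exercise the packing round-trip and the layout change:

```python
import torch
from text_generation_server.utils.awq.conversion_utils import pack, unpack, fast_awq_to_gptq

in_features, out_features, group_size = 64, 32, 32

# Random 4-bit values standing in for AWQ-style column-packed tensors:
# qweight is (in_features, out_features // 8), qzeros is (in_features // group_size, out_features // 8).
iweight = torch.randint(0, 16, (in_features, out_features), dtype=torch.int32)
izeros = torch.randint(0, 16, (in_features // group_size, out_features), dtype=torch.int32)
awq_qweight = pack(iweight, direction="column")
awq_qzeros = pack(izeros, direction="column")

# pack/unpack are inverses on 4-bit data (up to dtype).
assert torch.equal(unpack(awq_qweight, direction="column").to(iweight.dtype), iweight)

# The conversion re-packs the weights row-wise, the layout the Exllama/GPTQ kernels expect.
gptq_qweight, gptq_qzeros = fast_awq_to_gptq(awq_qweight, awq_qzeros)
assert gptq_qweight.shape == (in_features // 8, out_features)
assert gptq_qzeros.shape == (in_features // group_size, out_features // 8)
```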
@@ -182,7 +182,7 @@ try:
            )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)

            zeros = (zeros >> zeros_shifter[None, :]) & maxq
-           zeros = zeros + 1
+           zeros = (zeros + 1) & maxq  # eventually avoid overflow

            a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, BLOCK_SIZE_K)
            b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
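The Triton dequantization path gets the same guard: `maxq` is already used just above as the bit mask when extracting the zeros, so `(zeros + 1) & maxq` mirrors the `(... + 1) & 0x0F` change in the CUDA/HIP kernels.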
@@ -349,6 +349,13 @@ def get_linear(weight, bias, quantize):
            raise NotImplementedError(
                f"The passed weight is not `awq` compatible, loader needs to be updated."
            )
+       if IS_ROCM_SYSTEM:
+           raise NotImplementedError(
+               "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
+               "to use Exllama/GPTQ kernels for AWQ inference."
+           )
+       if not HAS_AWQ:
+           raise NotImplementedError("You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly")
        linear = WQLinear(
            w_bit=bits,
            group_size=groupsize,
@@ -46,7 +46,6 @@ class Weights:
        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> (str, str):
-
        names = [tensor_name]
        if self.prefix is not None:
            prefixed = f"{self.prefix}.{tensor_name}"
@@ -154,15 +153,30 @@ class Weights:
                    f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                )

+           bits, groupsize, _, quant_method = self._get_gptq_params()
+
            qzeros = self._get_qweight(f"{prefix}.qzeros")
            scales = self._get_qweight(f"{prefix}.scales")
            scales = scales.to(dtype=self.dtype)
-           if quantize == "gptq":
+
+           if quantize == "gptq" and quant_method == "gptq":
                g_idx = self.get_tensor(f"{prefix}.g_idx")
+           elif quantize == "gptq" and quant_method == "awq":
+               log_once(
+                   logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+               )
+               from text_generation_server.utils.awq.conversion_utils import (
+                   fast_awq_to_gptq,
+               )
+
+               qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+               g_idx = (
+                   torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
+                   // groupsize
+               ).to(dtype=torch.int32)
            else:
                g_idx = None

-           bits, groupsize, _ = self._get_gptq_params()
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
        else:
            slice_ = self._get_slice(f"{prefix}.weight")
@@ -204,20 +218,40 @@ class Weights:
                [self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
            )

-           if quantize == "gptq":
-               w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
-               for w2 in w[1:]:
-                   torch.testing.assert_close(w2, w[0])
-               g_idx = w[0]
-           else:
-               g_idx = None
+           bits, groupsize, desc_act, quant_method = self._get_gptq_params()

-           bits, groupsize, desc_act = self._get_gptq_params()
            from text_generation_server.utils.layers import HAS_EXLLAMA

            use_exllama = (
                bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
            )

+           if quantize == "gptq" and quant_method == "gptq":
+               w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
+               for w2 in w[1:]:
+                   torch.testing.assert_close(w2, w[0])
+               g_idx = w[0]
+           elif quantize == "gptq" and quant_method == "awq":
+               log_once(
+                   logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+               )
+               from text_generation_server.utils.awq.conversion_utils import (
+                   fast_awq_to_gptq,
+               )
+
+               qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+               if use_exllama:
+                   g_idx = None
+               else:
+                   g_idx = (
+                       torch.arange(
+                           qweight.shape[0] * (32 // bits), device=qweight.device
+                       )
+                       // groupsize
+                   ).to(dtype=torch.int32)
+           else:
+               g_idx = None
+
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        else:
            w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -243,7 +277,7 @@ class Weights:
    def get_multi_weights_row(self, prefix: str, quantize: str):
        if quantize == "gptq":
            use_exllama = True
-           bits, groupsize, desc_act = self._get_gptq_params()
+           bits, groupsize, desc_act, quant_method = self._get_gptq_params()

            if bits != 4:
                use_exllama = False
@@ -252,8 +286,19 @@ class Weights:
                log_once(logger.warning, "Disabling exllama because desc_act=True")
                use_exllama = False

+           try:
+               qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
+           except RuntimeError:
+               raise RuntimeError(
+                   "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
+               )
+
+           if quant_method == "gptq":
+               g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+           elif quant_method == "awq":
+               g_idx = None
+
            if self.process_group.size() > 1:
-               g_idx = self.get_tensor(f"{prefix}.g_idx")
+               if g_idx is not None:
                    if (
                        not torch.equal(
@@ -269,13 +314,6 @@ class Weights:
                    # it would require to reorder input activations that are split unto several GPUs
                    use_exllama = False

-           try:
-               qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
-           except RuntimeError:
-               raise RuntimeError(
-                   "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
-               )
-
            from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA

            if use_exllama:
@@ -289,8 +327,6 @@ class Weights:
            else:
                log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

-           g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-
            if use_exllama and groupsize != -1:
                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                scales = self.get_sharded(f"{prefix}.scales", dim=0)
@@ -298,12 +334,31 @@ class Weights:
                qzeros = self.get_tensor(f"{prefix}.qzeros")
                scales = self.get_tensor(f"{prefix}.scales")

-           if use_exllama:
+           if use_exllama and g_idx is not None:
                g_idx = g_idx - g_idx[0]

+           if quant_method == "awq":
+               log_once(
+                   logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
+               )
+               from text_generation_server.utils.awq.conversion_utils import (
+                   fast_awq_to_gptq,
+               )
+
+               qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
+               if use_exllama:
+                   g_idx = None
+               else:
+                   g_idx = (
+                       torch.arange(
+                           qweight.shape[0] * (32 // bits), device=qweight.device
+                       )
+                       // groupsize
+                   ).to(dtype=torch.int32)
+
            weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
        elif quantize == "awq":
-           bits, groupsize, _ = self._get_gptq_params()
+           bits, groupsize, _, _ = self._get_gptq_params()

            try:
                qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@@ -322,20 +377,22 @@ class Weights:
            weight = self.get_sharded(f"{prefix}.weight", dim=1)
        return weight

-   def _get_gptq_params(self) -> Tuple[int, int, int]:
+   def _get_gptq_params(self) -> Tuple[int, int, int, str]:
        try:
            bits = self.get_tensor("gptq_bits").item()
            groupsize = self.get_tensor("gptq_groupsize").item()
            desc_act = False
+           quant_method = "gptq"
        except (SafetensorError, RuntimeError) as e:
            try:
                bits = self.gptq_bits
                groupsize = self.gptq_groupsize
                desc_act = getattr(self, "gptq_desc_act", False)
+               quant_method = getattr(self, "quant_method", "gptq")
            except Exception:
                raise e

-       return bits, groupsize, desc_act
+       return bits, groupsize, desc_act, quant_method

    def _set_gptq_params(self, model_id, revision):
        filename = "config.json"
@@ -351,6 +408,7 @@ class Weights:
            self.gptq_bits = data["quantization_config"]["bits"]
            self.gptq_groupsize = data["quantization_config"]["group_size"]
            self.gptq_desc_act = data["quantization_config"]["desc_act"]
+           self.quant_method = data["quantization_config"]["quant_method"]
        except Exception:
            filename = "quantize_config.json"
            try:
@@ -365,6 +423,8 @@ class Weights:
            self.gptq_bits = data["bits"]
            self.gptq_groupsize = data["group_size"]
            self.gptq_desc_act = data["desc_act"]
+           if "version" in data and data["version"] == "GEMM":
+               self.quant_method = "awq"
        except Exception:
            filename = "quant_config.json"
            try:
@@ -379,5 +439,7 @@ class Weights:
            self.gptq_bits = data["w_bit"]
            self.gptq_groupsize = data["q_group_size"]
            self.gptq_desc_act = data["desc_act"]
+           if "version" in data and data["version"] == "GEMM":
+               self.quant_method = "awq"
        except Exception:
            pass
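Since converted AWQ checkpoints carry no `g_idx` tensor, the loader synthesizes a trivial one when the Exllama kernels are not used. A small sketch of what the `torch.arange(...) // groupsize` expression in the hunks above produces (sizes illustrative):

```python
import torch

in_features, groupsize = 4096, 128  # illustrative sizes
g_idx = (torch.arange(in_features) // groupsize).to(torch.int32)
# g_idx[i] is simply the quantization group of input channel i:
# tensor([0, 0, ..., 0, 1, 1, ..., 1, ..., 31, 31, ..., 31], dtype=torch.int32)
```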