ROCm AWQ support (#1514)

# What does this PR do?

This PR makes it possible to run AWQ models with the Exllama/GPTQ
kernels, specifically for ROCm devices that support the Exllama kernels
but not AWQ's GEMM kernel.

This is done by:
- unpacking, reordering and repacking the AWQ weights when `--quantize
gptq` is passed but the model's `quant_method` is `awq` (see the sketch
below).
- masking to 4 bits to avoid overflows when adding 1 to the zero-points
in the exllama and triton kernels.

Ref: https://github.com/casper-hansen/AutoAWQ/pull/313
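
As a rough sketch of what the loader does for an AWQ checkpoint served with
`--quantize gptq` (the shapes, `bits` and `groupsize` values below are
illustrative, and the random tensors stand in for tensors read from the
checkpoint; `fast_awq_to_gptq` is the helper added by this PR):

```python
import torch
from text_generation_server.utils.awq.conversion_utils import fast_awq_to_gptq

bits, groupsize = 4, 128                  # assumed quantization parameters
in_features, out_features = 4096, 4096    # assumed layer shape

# AWQ packs along columns: qweight is (in_features, out_features * bits / 32),
# qzeros is (in_features / groupsize, out_features * bits / 32).
awq_qweight = torch.randint(
    0, 2**31 - 1, (in_features, out_features * bits // 32), dtype=torch.int32
)
awq_qzeros = torch.randint(
    0, 2**31 - 1, (in_features // groupsize, out_features * bits // 32), dtype=torch.int32
)

# Repack into the GPTQ/Exllama layout (row-packed weights, column-packed zeros)
# and shift the zero-points by -1, since the GPTQ kernels add 1 back at runtime.
qweight, qzeros = fast_awq_to_gptq(awq_qweight, awq_qzeros)

# AWQ checkpoints carry no g_idx, so a trivial one is synthesized when the
# kernels need it: row i belongs to group i // groupsize.
g_idx = (
    torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device) // groupsize
).to(dtype=torch.int32)
```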

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.


---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Author: Ilyas Moutawwakil
Date: 2024-02-09 10:45:16 +01:00 (committed by GitHub)
parent c5ef81bed5
commit a4e5801684
15 changed files with 737 additions and 551 deletions

View File

@ -11,78 +11,79 @@
},
{
"id": 4321,
"logprob": -9.59375,
"logprob": -9.7890625,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6640625,
"logprob": -9.625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 29918,
"logprob": -2.3867188,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -2.8183594,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -1.6367188,
"logprob": -2.3359375,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.0527344,
"logprob": -1.8779297,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.6542969,
"logprob": -1.2744141,
"special": false,
"text": " request"
},
{
"id": 29918,
"logprob": -0.056121826,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -0.01600647,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -0.87939453,
"logprob": -1.6933594,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.7529297,
"logprob": -1.4648438,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.2980957,
"logprob": -0.15600586,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.8027344,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.23022461,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.0069885254,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.02218628,
"special": false,
"text": "\n"
}
]
],
"top_tokens": null
},
"generated_text": "_uri\nTest request_uri\nTest request"
"generated_text": "\nTest request\nTest request\nTest request\n"
}

View File

@ -11,12 +11,12 @@
},
{
"id": 4321,
"logprob": -9.6015625,
"logprob": -9.84375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6640625,
"logprob": -9.6015625,
"text": "request"
}
],
@ -24,13 +24,13 @@
"tokens": [
{
"id": 29899,
"logprob": -1.1640625,
"logprob": -1.5625,
"special": false,
"text": "-"
},
{
"id": 1454,
"logprob": -0.07543945,
"logprob": -0.20410156,
"special": false,
"text": "for"
},
@ -54,19 +54,19 @@
},
{
"id": 396,
"logprob": -0.2956543,
"logprob": -0.27685547,
"special": false,
"text": " #"
},
{
"id": 29906,
"logprob": -0.52734375,
"logprob": -0.4970703,
"special": false,
"text": "2"
},
{
"id": 29900,
"logprob": -0.6899414,
"logprob": -0.80615234,
"special": false,
"text": "0"
},
@ -77,12 +77,13 @@
"text": "1"
},
{
"id": 29946,
"logprob": -1.5068359,
"id": 29955,
"logprob": -1.0751953,
"special": false,
"text": "4"
"text": "7"
}
]
],
"top_tokens": null
},
"generated_text": "Test request-for-comment: #2014"
"generated_text": "Test request-for-comment: #2017"
}

View File

@ -12,80 +12,81 @@
},
{
"id": 4321,
"logprob": -9.6015625,
"logprob": -9.828125,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.671875,
"logprob": -9.609375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 29918,
"logprob": -2.3828125,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -2.8105469,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -1.6396484,
"logprob": -2.3300781,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.0546875,
"logprob": -1.8740234,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.6513672,
"logprob": -1.2646484,
"special": false,
"text": " request"
},
{
"id": 29918,
"logprob": -0.056365967,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -0.016082764,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -0.87841797,
"logprob": -1.7158203,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.7548828,
"logprob": -1.4667969,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.29711914,
"logprob": -0.15344238,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.81591797,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22973633,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007045746,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021957397,
"special": false,
"text": "\n"
}
]
],
"top_tokens": null
},
"generated_text": "_uri\nTest request_uri\nTest request"
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
@ -100,80 +101,81 @@
},
{
"id": 4321,
"logprob": -9.6015625,
"logprob": -9.84375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6640625,
"logprob": -9.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 29918,
"logprob": -2.3828125,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -2.828125,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -1.6386719,
"logprob": -2.3378906,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.0527344,
"logprob": -1.8779297,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.6542969,
"logprob": -1.2636719,
"special": false,
"text": " request"
},
{
"id": 29918,
"logprob": -0.055877686,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -0.016021729,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -0.8769531,
"logprob": -1.6992188,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.7583008,
"logprob": -1.4589844,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.29833984,
"logprob": -0.15344238,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.79052734,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22937012,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007041931,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.022140503,
"special": false,
"text": "\n"
}
]
],
"top_tokens": null
},
"generated_text": "_uri\nTest request_uri\nTest request"
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
@ -188,80 +190,81 @@
},
{
"id": 4321,
"logprob": -9.6015625,
"logprob": -9.84375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.671875,
"logprob": -9.609375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 29918,
"logprob": -2.3847656,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -2.8144531,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -1.6396484,
"logprob": -2.3261719,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.0527344,
"logprob": -1.8730469,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.65478516,
"logprob": -1.2587891,
"special": false,
"text": " request"
},
{
"id": 29918,
"logprob": -0.056243896,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -0.016143799,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -0.8808594,
"logprob": -1.6894531,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.75341797,
"logprob": -1.46875,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.2956543,
"logprob": -0.1541748,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.80322266,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22912598,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.0070495605,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021606445,
"special": false,
"text": "\n"
}
]
],
"top_tokens": null
},
"generated_text": "_uri\nTest request_uri\nTest request"
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
@ -276,79 +279,80 @@
},
{
"id": 4321,
"logprob": -9.6015625,
"logprob": -9.84375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6640625,
"logprob": -9.6015625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 29918,
"logprob": -2.3769531,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -2.8183594,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -1.6396484,
"logprob": -2.3320312,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.0546875,
"logprob": -1.875,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.65478516,
"logprob": -1.2646484,
"special": false,
"text": " request"
},
{
"id": 29918,
"logprob": -0.05557251,
"special": false,
"text": "_"
},
{
"id": 5338,
"logprob": -0.01612854,
"special": false,
"text": "uri"
},
{
"id": 13,
"logprob": -0.8730469,
"logprob": -1.6884766,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.7519531,
"logprob": -1.4589844,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.29785156,
"logprob": -0.15185547,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.79833984,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22827148,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.006996155,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021560669,
"special": false,
"text": "\n"
}
]
],
"top_tokens": null
},
"generated_text": "_uri\nTest request_uri\nTest request"
"generated_text": "\nTest request\nTest request\nTest request\n"
}
]

View File

@ -1,193 +1,194 @@
{
"generated_text": "\n return sum(L) / len(L)\n\n\ndef geometric_mean(L",
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"seed": null,
"prefill": [
{
"id": 589,
"text": "def",
"logprob": null
"logprob": null,
"text": "def"
},
{
"id": 3226,
"text": " ge",
"logprob": -9.0234375
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"text": "ometric",
"logprob": -9.0859375
"logprob": -7.5859375,
"text": "ometric"
},
{
"id": 81,
"text": "_",
"logprob": -0.25878906
"logprob": -0.2668457,
"text": "_"
},
{
"id": 6009,
"text": "mean",
"logprob": -2.2109375
"logprob": -1.6416016,
"text": "mean"
},
{
"id": 26,
"text": "(",
"logprob": -0.30371094
"logprob": -0.22705078,
"text": "("
},
{
"id": 62,
"text": "L",
"logprob": -5.6054688
"logprob": -5.2304688,
"text": "L"
},
{
"id": 44,
"text": ":",
"logprob": -3.0722656
"logprob": -3.0976562,
"text": ":"
},
{
"id": 1682,
"text": " List",
"logprob": -0.6879883
"logprob": -1.1044922,
"text": " List"
},
{
"id": 77,
"text": "[",
"logprob": -0.38500977
"logprob": -0.14294434,
"text": "["
},
{
"id": 1808,
"text": "float",
"logprob": -0.984375
"logprob": -0.32299805,
"text": "float"
},
{
"id": 10794,
"text": "]):",
"logprob": -2.5351562
"logprob": -2.8164062,
"text": "]):"
}
],
"seed": null,
"tokens": [
{
"id": 284,
"text": "\n ",
"logprob": -1.1738281,
"special": false
"logprob": -0.1282959,
"special": false,
"text": "\n "
},
{
"id": 442,
"text": " return",
"logprob": -0.95947266,
"special": false
"id": 1524,
"logprob": -0.97998047,
"special": false,
"text": " \"\"\""
},
{
"id": 3632,
"text": " sum",
"logprob": -1.4199219,
"special": false
"id": 284,
"logprob": -0.7006836,
"special": false,
"text": "\n "
},
{
"id": 26,
"text": "(",
"logprob": -0.085876465,
"special": false
"id": 14883,
"logprob": -2.1933594,
"special": false,
"text": " Calculate"
},
{
"id": 62,
"text": "L",
"logprob": -0.09875488,
"special": false
},
{
"id": 27,
"text": ")",
"logprob": -0.30517578,
"special": false
},
{
"id": 517,
"text": " /",
"logprob": -0.42089844,
"special": false
},
{
"id": 2069,
"text": " len",
"logprob": -0.042053223,
"special": false
},
{
"id": 26,
"text": "(",
"logprob": -0.0011806488,
"special": false
},
{
"id": 62,
"text": "L",
"logprob": -0.0005259514,
"special": false
},
{
"id": 27,
"text": ")",
"logprob": -0.0017633438,
"special": false
},
{
"id": 478,
"text": "\n\n",
"logprob": -0.69189453,
"special": false
},
{
"id": 203,
"text": "\n",
"logprob": -0.041870117,
"special": false
},
{
"id": 589,
"text": "def",
"logprob": -0.27856445,
"special": false
"id": 322,
"logprob": -0.2697754,
"special": false,
"text": " the"
},
{
"id": 3226,
"text": " ge",
"logprob": -1.7255859,
"special": false
"logprob": -0.0836792,
"special": false,
"text": " ge"
},
{
"id": 21017,
"text": "ometric",
"logprob": -0.011291504,
"special": false
"logprob": -0.018737793,
"special": false,
"text": "ometric"
},
{
"id": 81,
"text": "_",
"logprob": -0.008430481,
"special": false
"id": 5651,
"logprob": -0.028640747,
"special": false,
"text": " mean"
},
{
"id": 6009,
"text": "mean",
"logprob": -0.025787354,
"special": false
"id": 432,
"logprob": -0.29467773,
"special": false,
"text": " of"
},
{
"id": 26,
"text": "(",
"logprob": -0.073913574,
"special": false
"id": 312,
"logprob": -0.31518555,
"special": false,
"text": " a"
},
{
"id": 62,
"text": "L",
"logprob": -0.09967041,
"special": false
"id": 1149,
"logprob": -0.20605469,
"special": false,
"text": " list"
},
{
"id": 432,
"logprob": -0.23254395,
"special": false,
"text": " of"
},
{
"id": 7515,
"logprob": -0.4489746,
"special": false,
"text": " numbers"
},
{
"id": 32,
"logprob": -0.6044922,
"special": false,
"text": "."
},
{
"id": 446,
"logprob": -0.63964844,
"special": false,
"text": "\n\n "
},
{
"id": 499,
"logprob": -1.1953125,
"special": false,
"text": " :"
},
{
"id": 753,
"logprob": -0.03515625,
"special": false,
"text": "param"
},
{
"id": 498,
"logprob": -0.06311035,
"special": false,
"text": " L"
},
{
"id": 44,
"logprob": -0.003414154,
"special": false,
"text": ":"
},
{
"id": 1682,
"logprob": -1.3310547,
"special": false,
"text": " List"
}
]
}
],
"top_tokens": null
},
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a list of numbers.\n\n :param L: List"
}

View File

@ -11,57 +11,57 @@
},
{
"id": 3226,
"logprob": -9.0234375,
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"logprob": -7.5898438,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25830078,
"logprob": -0.26586914,
"text": "_"
},
{
"id": 6009,
"logprob": -2.1875,
"logprob": -1.6347656,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30004883,
"logprob": -0.22705078,
"text": "("
},
{
"id": 62,
"logprob": -5.6171875,
"logprob": -5.2382812,
"text": "L"
},
{
"id": 44,
"logprob": -3.078125,
"logprob": -3.0996094,
"text": ":"
},
{
"id": 1682,
"logprob": -0.68066406,
"logprob": -1.1025391,
"text": " List"
},
{
"id": 77,
"logprob": -0.38745117,
"logprob": -0.14294434,
"text": "["
},
{
"id": 1808,
"logprob": -0.9453125,
"logprob": -0.32226562,
"text": "float"
},
{
"id": 10794,
"logprob": -2.5371094,
"logprob": -2.8164062,
"text": "]):"
}
],
@ -69,19 +69,19 @@
"tokens": [
{
"id": 284,
"logprob": -0.051635742,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": 0.0,
"logprob": -1.3134766,
"special": false,
"text": " return"
},
{
"id": 11665,
"logprob": -1.2236328,
"logprob": -0.10021973,
"special": false,
"text": " reduce"
},
@ -129,7 +129,7 @@
},
{
"id": 319,
"logprob": 0.0,
"logprob": -0.42871094,
"special": false,
"text": " *"
},
@ -158,36 +158,37 @@
"text": ")"
},
{
"id": 203,
"logprob": -0.12695312,
"special": false,
"text": "\n"
},
{
"id": 203,
"id": 1115,
"logprob": 0.0,
"special": false,
"text": "\n"
"text": " **"
},
{
"id": 589,
"id": 308,
"logprob": 0.0,
"special": false,
"text": "def"
"text": " ("
},
{
"id": 3226,
"id": 35,
"logprob": 0.0,
"special": false,
"text": " ge"
"text": "1"
},
{
"id": 21017,
"id": 32,
"logprob": -0.31323242,
"special": false,
"text": "."
},
{
"id": 34,
"logprob": 0.0,
"special": false,
"text": "ometric"
"text": "0"
}
]
],
"top_tokens": null
},
"generated_text": "\n return reduce(lambda x, y: x * y, L)\n\ndef geometric"
"generated_text": "\n return reduce(lambda x, y: x * y, L) ** (1.0"
}

View File

@ -12,57 +12,57 @@
},
{
"id": 3226,
"logprob": -9.0234375,
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"logprob": -7.5820312,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25927734,
"logprob": -0.26708984,
"text": "_"
},
{
"id": 6009,
"logprob": -2.25,
"logprob": -1.6386719,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30126953,
"logprob": -0.22717285,
"text": "("
},
{
"id": 62,
"logprob": -5.7539062,
"logprob": -5.234375,
"text": "L"
},
{
"id": 44,
"logprob": -3.0878906,
"logprob": -3.1015625,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6845703,
"logprob": -1.1083984,
"text": " List"
},
{
"id": 77,
"logprob": -0.3918457,
"logprob": -0.14294434,
"text": "["
},
{
"id": 1808,
"logprob": -0.8798828,
"logprob": -0.32592773,
"text": "float"
},
{
"id": 10794,
"logprob": -2.4980469,
"logprob": -2.8164062,
"text": "]):"
}
],
@ -70,67 +70,68 @@
"tokens": [
{
"id": 284,
"logprob": -1.1533203,
"logprob": -0.12817383,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.91796875,
"id": 1524,
"logprob": -0.9863281,
"special": false,
"text": " return"
"text": " \"\"\""
},
{
"id": 3632,
"logprob": -1.3291016,
"id": 284,
"logprob": -0.7011719,
"special": false,
"text": " sum"
"text": "\n "
},
{
"id": 26,
"logprob": -0.08062744,
"id": 14883,
"logprob": -2.2050781,
"special": false,
"text": "("
"text": " Calculate"
},
{
"id": 62,
"logprob": -0.097717285,
"id": 322,
"logprob": -0.2668457,
"special": false,
"text": "L"
"text": " the"
},
{
"id": 27,
"logprob": -0.29003906,
"id": 3226,
"logprob": -0.08465576,
"special": false,
"text": ")"
"text": " ge"
},
{
"id": 517,
"logprob": -0.34958984,
"id": 21017,
"logprob": -0.019012451,
"special": false,
"text": " /"
"text": "ometric"
},
{
"id": 2069,
"logprob": -0.03829956,
"id": 5651,
"logprob": -0.028625488,
"special": false,
"text": " len"
"text": " mean"
},
{
"id": 26,
"logprob": -0.0011987686,
"id": 432,
"logprob": -0.29418945,
"special": false,
"text": "("
"text": " of"
},
{
"id": 62,
"logprob": -0.00050878525,
"id": 312,
"logprob": -0.3161621,
"special": false,
"text": "L"
"text": " a"
}
]
],
"top_tokens": null
},
"generated_text": "\n return sum(L) / len(L"
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
},
{
"details": {
@ -145,57 +146,57 @@
},
{
"id": 3226,
"logprob": -9.0234375,
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"logprob": -7.59375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25878906,
"logprob": -0.26953125,
"text": "_"
},
{
"id": 6009,
"logprob": -2.2109375,
"logprob": -1.640625,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30371094,
"logprob": -0.22705078,
"text": "("
},
{
"id": 62,
"logprob": -5.6054688,
"logprob": -5.234375,
"text": "L"
},
{
"id": 44,
"logprob": -3.0722656,
"logprob": -3.1132812,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6879883,
"logprob": -1.1123047,
"text": " List"
},
{
"id": 77,
"logprob": -0.38500977,
"logprob": -0.14294434,
"text": "["
},
{
"id": 1808,
"logprob": -0.984375,
"logprob": -0.32299805,
"text": "float"
},
{
"id": 10794,
"logprob": -2.5351562,
"logprob": -2.8164062,
"text": "]):"
}
],
@ -203,67 +204,68 @@
"tokens": [
{
"id": 284,
"logprob": -1.1738281,
"logprob": -0.12854004,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.9584961,
"id": 1524,
"logprob": -0.9897461,
"special": false,
"text": " return"
"text": " \"\"\""
},
{
"id": 3632,
"logprob": -1.4169922,
"id": 284,
"logprob": -0.69970703,
"special": false,
"text": " sum"
"text": "\n "
},
{
"id": 26,
"logprob": -0.085876465,
"id": 14883,
"logprob": -2.2050781,
"special": false,
"text": "("
"text": " Calculate"
},
{
"id": 62,
"logprob": -0.0982666,
"id": 322,
"logprob": -0.2668457,
"special": false,
"text": "L"
"text": " the"
},
{
"id": 27,
"logprob": -0.3022461,
"id": 3226,
"logprob": -0.08496094,
"special": false,
"text": ")"
"text": " ge"
},
{
"id": 517,
"logprob": -0.40504883,
"id": 21017,
"logprob": -0.019012451,
"special": false,
"text": " /"
"text": "ometric"
},
{
"id": 2069,
"logprob": -0.041656494,
"id": 5651,
"logprob": -0.029037476,
"special": false,
"text": " len"
"text": " mean"
},
{
"id": 26,
"logprob": -0.0011844635,
"id": 432,
"logprob": -0.2939453,
"special": false,
"text": "("
"text": " of"
},
{
"id": 62,
"logprob": -0.0005264282,
"id": 312,
"logprob": -0.31591797,
"special": false,
"text": "L"
"text": " a"
}
]
],
"top_tokens": null
},
"generated_text": "\n return sum(L) / len(L"
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
},
{
"details": {
@ -278,57 +280,57 @@
},
{
"id": 3226,
"logprob": -9.0234375,
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"logprob": -7.5859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25927734,
"logprob": -0.26586914,
"text": "_"
},
{
"id": 6009,
"logprob": -2.25,
"logprob": -1.6347656,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30126953,
"logprob": -0.22766113,
"text": "("
},
{
"id": 62,
"logprob": -5.7539062,
"logprob": -5.2265625,
"text": "L"
},
{
"id": 44,
"logprob": -3.0878906,
"logprob": -3.0976562,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6845703,
"logprob": -1.1025391,
"text": " List"
},
{
"id": 77,
"logprob": -0.3918457,
"logprob": -0.1427002,
"text": "["
},
{
"id": 1808,
"logprob": -0.8798828,
"logprob": -0.32592773,
"text": "float"
},
{
"id": 10794,
"logprob": -2.4980469,
"logprob": -2.8164062,
"text": "]):"
}
],
@ -336,67 +338,68 @@
"tokens": [
{
"id": 284,
"logprob": -1.1533203,
"logprob": -0.13012695,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.9165039,
"id": 1524,
"logprob": -0.98046875,
"special": false,
"text": " return"
"text": " \"\"\""
},
{
"id": 3632,
"logprob": -1.328125,
"id": 284,
"logprob": -0.69921875,
"special": false,
"text": " sum"
"text": "\n "
},
{
"id": 26,
"logprob": -0.07946777,
"id": 14883,
"logprob": -2.1992188,
"special": false,
"text": "("
"text": " Calculate"
},
{
"id": 62,
"logprob": -0.09820557,
"id": 322,
"logprob": -0.2668457,
"special": false,
"text": "L"
"text": " the"
},
{
"id": 27,
"logprob": -0.28930664,
"id": 3226,
"logprob": -0.083496094,
"special": false,
"text": ")"
"text": " ge"
},
{
"id": 517,
"logprob": -0.34592773,
"id": 21017,
"logprob": -0.01902771,
"special": false,
"text": " /"
"text": "ometric"
},
{
"id": 2069,
"logprob": -0.038330078,
"id": 5651,
"logprob": -0.029006958,
"special": false,
"text": " len"
"text": " mean"
},
{
"id": 26,
"logprob": -0.0011940002,
"id": 432,
"logprob": -0.29248047,
"special": false,
"text": "("
"text": " of"
},
{
"id": 62,
"logprob": -0.00050878525,
"id": 312,
"logprob": -0.3161621,
"special": false,
"text": "L"
"text": " a"
}
]
],
"top_tokens": null
},
"generated_text": "\n return sum(L) / len(L"
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
},
{
"details": {
@ -411,57 +414,57 @@
},
{
"id": 3226,
"logprob": -9.0234375,
"logprob": -8.5859375,
"text": " ge"
},
{
"id": 21017,
"logprob": -9.0859375,
"logprob": -7.5859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.25927734,
"logprob": -0.26904297,
"text": "_"
},
{
"id": 6009,
"logprob": -2.25,
"logprob": -1.6386719,
"text": "mean"
},
{
"id": 26,
"logprob": -0.30126953,
"logprob": -0.22705078,
"text": "("
},
{
"id": 62,
"logprob": -5.7539062,
"logprob": -5.234375,
"text": "L"
},
{
"id": 44,
"logprob": -3.0878906,
"logprob": -3.1132812,
"text": ":"
},
{
"id": 1682,
"logprob": -0.6845703,
"logprob": -1.1074219,
"text": " List"
},
{
"id": 77,
"logprob": -0.3918457,
"logprob": -0.14477539,
"text": "["
},
{
"id": 1808,
"logprob": -0.8798828,
"logprob": -0.3256836,
"text": "float"
},
{
"id": 10794,
"logprob": -2.4980469,
"logprob": -2.8027344,
"text": "]):"
}
],
@ -469,66 +472,67 @@
"tokens": [
{
"id": 284,
"logprob": -1.1533203,
"logprob": -0.12915039,
"special": false,
"text": "\n "
},
{
"id": 442,
"logprob": -0.91259766,
"id": 1524,
"logprob": -0.98535156,
"special": false,
"text": " return"
"text": " \"\"\""
},
{
"id": 3632,
"logprob": -1.3251953,
"id": 284,
"logprob": -0.69921875,
"special": false,
"text": " sum"
"text": "\n "
},
{
"id": 26,
"logprob": -0.08062744,
"id": 14883,
"logprob": -2.2011719,
"special": false,
"text": "("
"text": " Calculate"
},
{
"id": 62,
"logprob": -0.09906006,
"id": 322,
"logprob": -0.26708984,
"special": false,
"text": "L"
"text": " the"
},
{
"id": 27,
"logprob": -0.28979492,
"id": 3226,
"logprob": -0.08502197,
"special": false,
"text": ")"
"text": " ge"
},
{
"id": 517,
"logprob": -0.35958984,
"id": 21017,
"logprob": -0.019012451,
"special": false,
"text": " /"
"text": "ometric"
},
{
"id": 2069,
"logprob": -0.038604736,
"id": 5651,
"logprob": -0.028625488,
"special": false,
"text": " len"
"text": " mean"
},
{
"id": 26,
"logprob": -0.0011901855,
"id": 432,
"logprob": -0.29589844,
"special": false,
"text": "("
"text": " of"
},
{
"id": 62,
"logprob": -0.0005078316,
"id": 312,
"logprob": -0.31591797,
"special": false,
"text": "L"
"text": " a"
}
]
],
"top_tokens": null
},
"generated_text": "\n return sum(L) / len(L"
"generated_text": "\n \"\"\"\n Calculate the geometric mean of a"
}
]

View File

@ -85,7 +85,7 @@ __global__ void q4_matmul_kernel
if constexpr (use_half2)
{
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
@ -93,7 +93,7 @@ __global__ void q4_matmul_kernel
else
{
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
@ -110,7 +110,7 @@ __global__ void q4_matmul_kernel
{
int group = k / groupsize;
half2 w_scale = w_scales_.item_half2half2(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
@ -119,7 +119,7 @@ __global__ void q4_matmul_kernel
{
int group = k / groupsize;
half w_scale = w_scales_.item(group, w_column);
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
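
For context on the `& 0x0F` masks above: after the AWQ->GPTQ conversion the stored zero-points are shifted down by one, so a converted zero-point can wrap around to 15, and the kernels' `+ 1` would then overflow the 4-bit range. A minimal illustration in plain Python (not taken from the PR):

```python
# An AWQ zero-point of 0 is stored as 0 - 1, which packs into 4 bits as 0xF.
stored = (0 - 1) & 0x0F      # 15
# The GPTQ kernels add 1 back; unmasked this leaves the 4-bit range...
assert stored + 1 == 16
# ...while masking wraps it back to the intended zero-point.
assert (stored + 1) & 0x0F == 0
```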

View File

@ -189,7 +189,7 @@ __global__ void reconstruct_kernel
int group = row / groupsize;
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F;
uint32_t w_read = w_.item_uint32_t(row, column);
half* out_ptr = out_.item_ptr(row, column);

View File

@ -152,10 +152,10 @@ __global__ void gemm_half_q_half_gptq_kernel
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
// __syncthreads();
@ -174,10 +174,10 @@ __global__ void gemm_half_q_half_gptq_kernel
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
}
#pragma unroll

View File

@ -237,10 +237,10 @@ __global__ void reconstruct_gptq_kernel
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
__syncthreads();
@ -255,10 +255,10 @@ __global__ void reconstruct_gptq_kernel
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++)

View File

@ -69,9 +69,17 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device)
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
bits, groupsize, _ = weights._get_gptq_params()
bits, groupsize, _, quant_method, = weights._get_gptq_params()
if quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
elif quant_method == "awq":
g_idx = None
from text_generation_server.utils.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
from text_generation_server.utils.layers import HAS_EXLLAMA

View File

@ -0,0 +1,97 @@
import torch
from typing import List
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def pack(imatrix: torch.Tensor, direction: str = "column"):
"""
Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of packing, either "column" or "row"
Returns:
qmatrix (torch.Tensor): packed matrix of integers
"""
shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)
imatrix = imatrix.to(torch.int8) & 0x0F  # mask to 4 bits to correct possible overflow
if direction == "column":
imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
elif direction == "row":
imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
qmatrix = qmatrix.to(torch.int32)
return qmatrix
def unpack(qmatrix: torch.Tensor, direction: str = "column"):
"""
Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
Args:
qmatrix (torch.Tensor): matrix of packed integers
direction (str): direction of unpacking, either "column" or "row"
Returns:
imatrix (torch.Tensor): matrix of integers
"""
shifts = torch.arange(0, 32, 4, device=qmatrix.device)
if direction == "column":
imatrix = torch.bitwise_right_shift(
qmatrix[:, :, None], shifts[None, None, :]
).view(qmatrix.shape[0], -1)
elif direction == "row":
imatrix = torch.bitwise_right_shift(
qmatrix[:, None, :], shifts[None, :, None]
).view(-1, qmatrix.shape[-1])
imatrix = imatrix.to(torch.int8) & 0x0F  # mask to 4 bits to correct possible overflow
return imatrix
def apply_order(
imatrix: torch.Tensor,
direction: str = "column",
order: List[int] = AWQ_PACK_ORDER,
):
"""
Applies the order to a 4-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of applying order, either "column" or "row"
order (List[int]): order to apply, default is AWQ_PACK_ORDER
Returns:
imatrix (torch.Tensor): matrix of integers
"""
if direction == "column":
imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)
elif direction == "row":
imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)
return imatrix
def fast_awq_to_gptq(qweight, qzeros):
# awq uses column packing for both weights and zeros
izeros = unpack(qzeros, direction="column")
iweights = unpack(qweight, direction="column")
# Reverse the order of the iweight and izeros tensors
izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
# Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)
izeros = izeros - 1
# exllama uses row packing for weights and column packing for zeros
qzeros = pack(izeros, direction="column")
qweight = pack(iweights, direction="row")
return qweight, qzeros
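
A hypothetical round-trip sanity check for the helpers above (not part of this PR's test suite); `pack` and `unpack` are the functions defined in this file:

```python
import torch

x = torch.randint(0, 16, (8, 64), dtype=torch.int32)   # random 4-bit values
packed = pack(x, direction="column")                    # (8, 8) int32
unpacked = unpack(packed, direction="column")           # (8, 64) int8
assert torch.equal(unpacked, x.to(torch.int8))
```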

View File

@ -182,7 +182,7 @@ try:
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1
zeros = (zeros + 1) & maxq  # mask to avoid possible overflow
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

View File

@ -349,6 +349,13 @@ def get_linear(weight, bias, quantize):
raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated."
)
if IS_ROCM_SYSTEM:
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."
)
if not HAS_AWQ:
raise NotImplementedError("You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly")
linear = WQLinear(
w_bit=bits,
group_size=groupsize,

View File

@ -46,7 +46,6 @@ class Weights:
return self._handles[filename]
def get_filename(self, tensor_name: str) -> (str, str):
names = [tensor_name]
if self.prefix is not None:
prefixed = f"{self.prefix}.{tensor_name}"
@ -154,15 +153,30 @@ class Weights:
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)
bits, groupsize, _, quant_method = self._get_gptq_params()
qzeros = self._get_qweight(f"{prefix}.qzeros")
scales = self._get_qweight(f"{prefix}.scales")
scales = scales.to(dtype=self.dtype)
if quantize == "gptq":
if quantize == "gptq" and quant_method == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx")
elif quantize == "gptq" and quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.utils.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = (
torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
// groupsize
).to(dtype=torch.int32)
else:
g_idx = None
bits, groupsize, _ = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
slice_ = self._get_slice(f"{prefix}.weight")
@ -204,20 +218,40 @@ class Weights:
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
if quantize == "gptq":
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
else:
g_idx = None
bits, groupsize, desc_act, quant_method = self._get_gptq_params()
bits, groupsize, desc_act = self._get_gptq_params()
from text_generation_server.utils.layers import HAS_EXLLAMA
use_exllama = (
bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
)
if quantize == "gptq" and quant_method == "gptq":
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif quantize == "gptq" and quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.utils.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // bits), device=qweight.device
)
// groupsize
).to(dtype=torch.int32)
else:
g_idx = None
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -243,7 +277,7 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
use_exllama = True
bits, groupsize, desc_act = self._get_gptq_params()
bits, groupsize, desc_act, quant_method = self._get_gptq_params()
if bits != 4:
use_exllama = False
@ -252,8 +286,19 @@ class Weights:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if quant_method == "gptq":
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
elif quant_method == "awq":
g_idx = None
if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
if g_idx is not None:
if (
not torch.equal(
@ -269,13 +314,6 @@ class Weights:
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA
if use_exllama:
@ -289,8 +327,6 @@ class Weights:
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if use_exllama and groupsize != -1:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
@ -298,12 +334,31 @@ class Weights:
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
if use_exllama:
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.utils.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // bits), device=qweight.device
)
// groupsize
).to(dtype=torch.int32)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
elif quantize == "awq":
bits, groupsize, _ = self._get_gptq_params()
bits, groupsize, _, _ = self._get_gptq_params()
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
@ -322,20 +377,22 @@ class Weights:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def _get_gptq_params(self) -> Tuple[int, int, int]:
def _get_gptq_params(self) -> Tuple[int, int, int, str]:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
desc_act = False
quant_method = "gptq"
except (SafetensorError, RuntimeError) as e:
try:
bits = self.gptq_bits
groupsize = self.gptq_groupsize
desc_act = getattr(self, "gptq_desc_act", False)
quant_method = getattr(self, "quant_method", "gptq")
except Exception:
raise e
return bits, groupsize, desc_act
return bits, groupsize, desc_act, quant_method
def _set_gptq_params(self, model_id, revision):
filename = "config.json"
@ -351,6 +408,7 @@ class Weights:
self.gptq_bits = data["quantization_config"]["bits"]
self.gptq_groupsize = data["quantization_config"]["group_size"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
self.quant_method = data["quantization_config"]["quant_method"]
except Exception:
filename = "quantize_config.json"
try:
@ -365,6 +423,8 @@ class Weights:
self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"]
self.gptq_desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
self.quant_method = "awq"
except Exception:
filename = "quant_config.json"
try:
@ -379,5 +439,7 @@ class Weights:
self.gptq_bits = data["w_bit"]
self.gptq_groupsize = data["q_group_size"]
self.gptq_desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
self.quant_method = "awq"
except Exception:
pass
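
For reference, the detection added above keys off the `version` field that AutoAWQ writes into its quantization config (the loader probes `config.json`, then `quantize_config.json`, then `quant_config.json`). A hypothetical config and the resulting decision, with illustrative field values:

```python
data = {"w_bit": 4, "q_group_size": 128, "version": "GEMM"}  # illustrative values only
quant_method = "awq" if data.get("version") == "GEMM" else "gptq"
assert quant_method == "awq"
```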