Use GPTQ-Marlin for supported GPTQ configurations (#2111)
GPTQ-Marlin is currently the best-performing kernel for GPTQ models. So let's use it by default if the kernels are installed, the GPU supports it, and the kernels support the configuration. For models generated by `text-generation-server quantize`, use `sym=False`. This subcommand symmetric quantization since the beginning and incorrectly reporting the model to be symmetric will use GPTQ-Marlin (which does not support asymmetric quantization).
This commit is contained in:
parent
0d97a93c1e
commit
2ce8019480
|
@ -1,84 +0,0 @@
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 198,
|
|
||||||
"logprob": -2.5742188,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -1.6230469,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3270,
|
|
||||||
"logprob": -2.046875,
|
|
||||||
"special": false,
|
|
||||||
"text": " \"\"\"\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -0.015281677,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 422,
|
|
||||||
"logprob": -2.1425781,
|
|
||||||
"special": false,
|
|
||||||
"text": " if"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -0.9238281,
|
|
||||||
"special": false,
|
|
||||||
"text": " request"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13204,
|
|
||||||
"logprob": -0.076660156,
|
|
||||||
"special": false,
|
|
||||||
"text": ".method"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 624,
|
|
||||||
"logprob": -0.021987915,
|
|
||||||
"special": false,
|
|
||||||
"text": " =="
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 364,
|
|
||||||
"logprob": -0.39208984,
|
|
||||||
"special": false,
|
|
||||||
"text": " '"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3019,
|
|
||||||
"logprob": -0.10821533,
|
|
||||||
"special": false,
|
|
||||||
"text": "POST"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
|
||||||
}
|
|
|
@ -1,84 +0,0 @@
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": 0,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 13,
|
|
||||||
"logprob": -2.2539062,
|
|
||||||
"special": false,
|
|
||||||
"text": "."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 578,
|
|
||||||
"logprob": -0.15563965,
|
|
||||||
"special": false,
|
|
||||||
"text": " The"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3622,
|
|
||||||
"logprob": -0.8203125,
|
|
||||||
"special": false,
|
|
||||||
"text": " server"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 706,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " has"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 539,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " not"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3686,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " yet"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3288,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " sent"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 904,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " any"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 828,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 382,
|
|
||||||
"logprob": -1.5517578,
|
|
||||||
"special": false,
|
|
||||||
"text": ".\n\n"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "Test request. The server has not yet sent any data.\n\n"
|
|
||||||
}
|
|
|
@ -1,338 +0,0 @@
|
||||||
[
|
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 198,
|
|
||||||
"logprob": -2.5742188,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -1.6220703,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3270,
|
|
||||||
"logprob": -2.0410156,
|
|
||||||
"special": false,
|
|
||||||
"text": " \"\"\"\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -0.015281677,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 422,
|
|
||||||
"logprob": -2.1445312,
|
|
||||||
"special": false,
|
|
||||||
"text": " if"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -0.92333984,
|
|
||||||
"special": false,
|
|
||||||
"text": " request"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13204,
|
|
||||||
"logprob": -0.07672119,
|
|
||||||
"special": false,
|
|
||||||
"text": ".method"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 624,
|
|
||||||
"logprob": -0.021987915,
|
|
||||||
"special": false,
|
|
||||||
"text": " =="
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 364,
|
|
||||||
"logprob": -0.39208984,
|
|
||||||
"special": false,
|
|
||||||
"text": " '"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3019,
|
|
||||||
"logprob": -0.10638428,
|
|
||||||
"special": false,
|
|
||||||
"text": "POST"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 198,
|
|
||||||
"logprob": -2.5742188,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -1.6220703,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3270,
|
|
||||||
"logprob": -2.0410156,
|
|
||||||
"special": false,
|
|
||||||
"text": " \"\"\"\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -0.015281677,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 422,
|
|
||||||
"logprob": -2.1445312,
|
|
||||||
"special": false,
|
|
||||||
"text": " if"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -0.92333984,
|
|
||||||
"special": false,
|
|
||||||
"text": " request"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13204,
|
|
||||||
"logprob": -0.07672119,
|
|
||||||
"special": false,
|
|
||||||
"text": ".method"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 624,
|
|
||||||
"logprob": -0.021987915,
|
|
||||||
"special": false,
|
|
||||||
"text": " =="
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 364,
|
|
||||||
"logprob": -0.39208984,
|
|
||||||
"special": false,
|
|
||||||
"text": " '"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3019,
|
|
||||||
"logprob": -0.10638428,
|
|
||||||
"special": false,
|
|
||||||
"text": "POST"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 198,
|
|
||||||
"logprob": -2.5742188,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -1.6220703,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3270,
|
|
||||||
"logprob": -2.0410156,
|
|
||||||
"special": false,
|
|
||||||
"text": " \"\"\"\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -0.015281677,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 422,
|
|
||||||
"logprob": -2.1445312,
|
|
||||||
"special": false,
|
|
||||||
"text": " if"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -0.92333984,
|
|
||||||
"special": false,
|
|
||||||
"text": " request"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13204,
|
|
||||||
"logprob": -0.07672119,
|
|
||||||
"special": false,
|
|
||||||
"text": ".method"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 624,
|
|
||||||
"logprob": -0.021987915,
|
|
||||||
"special": false,
|
|
||||||
"text": " =="
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 364,
|
|
||||||
"logprob": -0.39208984,
|
|
||||||
"special": false,
|
|
||||||
"text": " '"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3019,
|
|
||||||
"logprob": -0.10638428,
|
|
||||||
"special": false,
|
|
||||||
"text": "POST"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"details": {
|
|
||||||
"best_of_sequences": null,
|
|
||||||
"finish_reason": "length",
|
|
||||||
"generated_tokens": 10,
|
|
||||||
"prefill": [
|
|
||||||
{
|
|
||||||
"id": 2323,
|
|
||||||
"logprob": null,
|
|
||||||
"text": "Test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -11.34375,
|
|
||||||
"text": " request"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"seed": null,
|
|
||||||
"tokens": [
|
|
||||||
{
|
|
||||||
"id": 198,
|
|
||||||
"logprob": -2.5742188,
|
|
||||||
"special": false,
|
|
||||||
"text": "\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -1.6220703,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3270,
|
|
||||||
"logprob": -2.0410156,
|
|
||||||
"special": false,
|
|
||||||
"text": " \"\"\"\n"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 262,
|
|
||||||
"logprob": -0.015281677,
|
|
||||||
"special": false,
|
|
||||||
"text": " "
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 422,
|
|
||||||
"logprob": -2.1445312,
|
|
||||||
"special": false,
|
|
||||||
"text": " if"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1715,
|
|
||||||
"logprob": -0.92333984,
|
|
||||||
"special": false,
|
|
||||||
"text": " request"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 13204,
|
|
||||||
"logprob": -0.07672119,
|
|
||||||
"special": false,
|
|
||||||
"text": ".method"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 624,
|
|
||||||
"logprob": -0.021987915,
|
|
||||||
"special": false,
|
|
||||||
"text": " =="
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 364,
|
|
||||||
"logprob": -0.39208984,
|
|
||||||
"special": false,
|
|
||||||
"text": " '"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3019,
|
|
||||||
"logprob": -0.10638428,
|
|
||||||
"special": false,
|
|
||||||
"text": "POST"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"top_tokens": null
|
|
||||||
},
|
|
||||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
|
||||||
}
|
|
||||||
]
|
|
|
@ -1,68 +0,0 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def flash_llama_gptq_marlin_handle(launcher):
|
|
||||||
with launcher(
|
|
||||||
"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
|
|
||||||
) as handle:
|
|
||||||
yield handle
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
|
|
||||||
await flash_llama_gptq_marlin_handle.health(300)
|
|
||||||
return flash_llama_gptq_marlin_handle.client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.release
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
|
|
||||||
response = await flash_llama_gptq_marlin.generate(
|
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
|
||||||
assert response == response_snapshot
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.release
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_gptq_marlin_all_params(
|
|
||||||
flash_llama_gptq_marlin, response_snapshot
|
|
||||||
):
|
|
||||||
response = await flash_llama_gptq_marlin.generate(
|
|
||||||
"Test request",
|
|
||||||
max_new_tokens=10,
|
|
||||||
repetition_penalty=1.2,
|
|
||||||
return_full_text=True,
|
|
||||||
temperature=0.5,
|
|
||||||
top_p=0.9,
|
|
||||||
top_k=10,
|
|
||||||
truncate=5,
|
|
||||||
typical_p=0.9,
|
|
||||||
watermark=True,
|
|
||||||
decoder_input_details=True,
|
|
||||||
seed=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.details.generated_tokens == 10
|
|
||||||
assert response == response_snapshot
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.release
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_gptq_marlin_load(
|
|
||||||
flash_llama_gptq_marlin, generate_load, response_snapshot
|
|
||||||
):
|
|
||||||
responses = await generate_load(
|
|
||||||
flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(responses) == 4
|
|
||||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
|
||||||
|
|
||||||
assert responses == response_snapshot
|
|
|
@ -7,6 +7,16 @@ from text_generation_server.utils.import_utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GPTQParams:
|
||||||
|
bits: int
|
||||||
|
checkpoint_format: Optional[str]
|
||||||
|
groupsize: int
|
||||||
|
desc_act: bool
|
||||||
|
quant_method: str
|
||||||
|
sym: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GPTQWeight:
|
class GPTQWeight:
|
||||||
qweight: torch.Tensor
|
qweight: torch.Tensor
|
||||||
|
|
|
@ -166,12 +166,17 @@ def get_linear(weight, bias, quantize):
|
||||||
|
|
||||||
elif quantize == "gptq":
|
elif quantize == "gptq":
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
|
from text_generation_server.layers.marlin import (
|
||||||
if not isinstance(weight, GPTQWeight):
|
GPTQMarlinLinear,
|
||||||
raise NotImplementedError(
|
GPTQMarlinWeight,
|
||||||
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(weight, GPTQMarlinWeight):
|
||||||
|
linear = GPTQMarlinLinear(
|
||||||
|
weight=weight,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
elif isinstance(weight, GPTQWeight):
|
||||||
if weight.use_exllama:
|
if weight.use_exllama:
|
||||||
try:
|
try:
|
||||||
from text_generation_server.layers.gptq import (
|
from text_generation_server.layers.gptq import (
|
||||||
|
@ -195,6 +200,11 @@ def get_linear(weight, bias, quantize):
|
||||||
weight.bits,
|
weight.bits,
|
||||||
weight.groupsize,
|
weight.groupsize,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"The passed weight is not `gptq` compatible, loader needs to be updated."
|
||||||
|
)
|
||||||
|
|
||||||
elif quantize == "awq":
|
elif quantize == "awq":
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
|
|
||||||
|
@ -226,18 +236,11 @@ def get_linear(weight, bias, quantize):
|
||||||
from text_generation_server.layers.marlin import (
|
from text_generation_server.layers.marlin import (
|
||||||
GPTQMarlin24Linear,
|
GPTQMarlin24Linear,
|
||||||
GPTQMarlin24Weight,
|
GPTQMarlin24Weight,
|
||||||
GPTQMarlinLinear,
|
|
||||||
GPTQMarlinWeight,
|
|
||||||
MarlinLinear,
|
MarlinLinear,
|
||||||
MarlinWeight,
|
MarlinWeight,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(weight, GPTQMarlinWeight):
|
if isinstance(weight, GPTQMarlin24Weight):
|
||||||
linear = GPTQMarlinLinear(
|
|
||||||
weight=weight,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
elif isinstance(weight, GPTQMarlin24Weight):
|
|
||||||
linear = GPTQMarlin24Linear(
|
linear = GPTQMarlin24Linear(
|
||||||
weight=weight,
|
weight=weight,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
|
|
|
@ -3,6 +3,8 @@ from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from text_generation_server.layers.gptq import GPTQParams
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -22,6 +24,19 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
|
||||||
MARLIN_TILE_SIZE = 16
|
MARLIN_TILE_SIZE = 16
|
||||||
|
|
||||||
|
|
||||||
|
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
|
||||||
|
return (
|
||||||
|
SYSTEM == "cuda"
|
||||||
|
and marlin_kernels is not None
|
||||||
|
and has_sm_8_0
|
||||||
|
and quantize == "gptq"
|
||||||
|
and gptq_params.quant_method == "gptq"
|
||||||
|
and gptq_params.bits in GPTQ_MARLIN_BITS
|
||||||
|
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
|
||||||
|
and gptq_params.sym
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_marlin_kernels():
|
def _check_marlin_kernels():
|
||||||
if not (SYSTEM == "cuda" and has_sm_8_0):
|
if not (SYSTEM == "cuda" and has_sm_8_0):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|
|
@ -1,25 +1,15 @@
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Union
|
||||||
from safetensors import safe_open, SafetensorError
|
from safetensors import safe_open, SafetensorError
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
import json
|
import json
|
||||||
|
from text_generation_server.layers.gptq import GPTQParams
|
||||||
from text_generation_server.utils.log import log_once
|
from text_generation_server.utils.log import log_once
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class _GPTQParams:
|
|
||||||
bits: int
|
|
||||||
checkpoint_format: Optional[str]
|
|
||||||
groupsize: int
|
|
||||||
desc_act: bool
|
|
||||||
quant_method: str
|
|
||||||
sym: bool
|
|
||||||
|
|
||||||
|
|
||||||
class Weights:
|
class Weights:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -212,6 +202,10 @@ class Weights:
|
||||||
"""
|
"""
|
||||||
if quantize in ["gptq", "awq"]:
|
if quantize in ["gptq", "awq"]:
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
|
from text_generation_server.layers.marlin import (
|
||||||
|
can_use_gptq_marlin,
|
||||||
|
repack_gptq_for_marlin,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = self.get_packed_sharded(
|
qweight = self.get_packed_sharded(
|
||||||
|
@ -221,17 +215,28 @@ class Weights:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
||||||
)
|
)
|
||||||
|
|
||||||
gptq_params = self._get_gptq_params()
|
|
||||||
|
|
||||||
qzeros = self.get_packed_sharded(
|
|
||||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
|
||||||
)
|
|
||||||
scales = self.get_packed_sharded(
|
scales = self.get_packed_sharded(
|
||||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||||
)
|
)
|
||||||
scales = scales.to(dtype=self.dtype)
|
scales = scales.to(dtype=self.dtype)
|
||||||
|
|
||||||
|
gptq_params = self._get_gptq_params()
|
||||||
|
if can_use_gptq_marlin(gptq_params, quantize):
|
||||||
|
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=gptq_params.bits,
|
||||||
|
desc_act=gptq_params.desc_act,
|
||||||
|
groupsize=gptq_params.groupsize,
|
||||||
|
sym=gptq_params.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
qzeros = self.get_packed_sharded(
|
||||||
|
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
if quantize == "gptq" and gptq_params.quant_method == "gptq":
|
if quantize == "gptq" and gptq_params.quant_method == "gptq":
|
||||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||||
elif quantize == "gptq" and gptq_params.quant_method == "awq":
|
elif quantize == "gptq" and gptq_params.quant_method == "awq":
|
||||||
|
@ -269,7 +274,6 @@ class Weights:
|
||||||
repack_gptq_for_marlin,
|
repack_gptq_for_marlin,
|
||||||
)
|
)
|
||||||
|
|
||||||
quant_method = getattr(self, "quant_method", "marlin")
|
|
||||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||||
if is_marlin_24:
|
if is_marlin_24:
|
||||||
B = self.get_packed_sharded(
|
B = self.get_packed_sharded(
|
||||||
|
@ -286,31 +290,6 @@ class Weights:
|
||||||
weight = GPTQMarlin24Weight(
|
weight = GPTQMarlin24Weight(
|
||||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||||
)
|
)
|
||||||
elif quant_method == "gptq":
|
|
||||||
gptq_params = self._get_gptq_params()
|
|
||||||
try:
|
|
||||||
qweight = self.get_packed_sharded(
|
|
||||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
|
||||||
)
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
|
||||||
)
|
|
||||||
|
|
||||||
scales = self.get_packed_sharded(
|
|
||||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
|
||||||
)
|
|
||||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
|
||||||
weight = repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=gptq_params.bits,
|
|
||||||
desc_act=gptq_params.desc_act,
|
|
||||||
groupsize=gptq_params.groupsize,
|
|
||||||
sym=gptq_params.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
B = self.get_packed_sharded(
|
B = self.get_packed_sharded(
|
||||||
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||||
|
@ -356,6 +335,10 @@ class Weights:
|
||||||
raise ValueError("get_multi_weights_col is not supported for exl2")
|
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||||
elif quantize in ["gptq", "awq"]:
|
elif quantize in ["gptq", "awq"]:
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
|
from text_generation_server.layers.marlin import (
|
||||||
|
can_use_gptq_marlin,
|
||||||
|
repack_gptq_for_marlin,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
qweight = torch.cat(
|
qweight = torch.cat(
|
||||||
|
@ -366,14 +349,31 @@ class Weights:
|
||||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||||
)
|
)
|
||||||
|
|
||||||
qzeros = torch.cat(
|
|
||||||
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
|
||||||
)
|
|
||||||
scales = torch.cat(
|
scales = torch.cat(
|
||||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
gptq_params = self._get_gptq_params()
|
gptq_params = self._get_gptq_params()
|
||||||
|
if can_use_gptq_marlin(gptq_params, quantize):
|
||||||
|
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||||
|
for w2 in w[1:]:
|
||||||
|
torch.testing.assert_close(w2, w[0])
|
||||||
|
g_idx = w[0]
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=gptq_params.bits,
|
||||||
|
desc_act=gptq_params.desc_act,
|
||||||
|
groupsize=gptq_params.groupsize,
|
||||||
|
sym=gptq_params.sym,
|
||||||
|
sharded_infeatures=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
qzeros = torch.cat(
|
||||||
|
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
|
||||||
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
||||||
|
|
||||||
|
@ -425,10 +425,8 @@ class Weights:
|
||||||
from text_generation_server.layers.marlin import (
|
from text_generation_server.layers.marlin import (
|
||||||
GPTQMarlin24Weight,
|
GPTQMarlin24Weight,
|
||||||
MarlinWeight,
|
MarlinWeight,
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
quant_method = getattr(self, "quant_method", "marlin")
|
|
||||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||||
if is_marlin_24:
|
if is_marlin_24:
|
||||||
try:
|
try:
|
||||||
|
@ -452,36 +450,6 @@ class Weights:
|
||||||
weight = GPTQMarlin24Weight(
|
weight = GPTQMarlin24Weight(
|
||||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||||
)
|
)
|
||||||
elif quant_method == "gptq":
|
|
||||||
gptq_params = self._get_gptq_params()
|
|
||||||
try:
|
|
||||||
qweight = torch.cat(
|
|
||||||
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes],
|
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
|
||||||
)
|
|
||||||
|
|
||||||
scales = torch.cat(
|
|
||||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
|
||||||
)
|
|
||||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
|
||||||
for w2 in w[1:]:
|
|
||||||
torch.testing.assert_close(w2, w[0])
|
|
||||||
g_idx = w[0]
|
|
||||||
|
|
||||||
weight = repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=gptq_params.bits,
|
|
||||||
desc_act=gptq_params.desc_act,
|
|
||||||
groupsize=gptq_params.groupsize,
|
|
||||||
sym=gptq_params.sym,
|
|
||||||
sharded_infeatures=False,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
B = torch.cat(
|
B = torch.cat(
|
||||||
|
@ -544,9 +512,41 @@ class Weights:
|
||||||
)
|
)
|
||||||
|
|
||||||
elif quantize == "gptq":
|
elif quantize == "gptq":
|
||||||
use_exllama = True
|
from text_generation_server.layers.marlin import (
|
||||||
gptq_params = self._get_gptq_params()
|
can_use_gptq_marlin,
|
||||||
|
repack_gptq_for_marlin,
|
||||||
|
)
|
||||||
|
|
||||||
|
gptq_params = self._get_gptq_params()
|
||||||
|
if can_use_gptq_marlin(gptq_params, quantize):
|
||||||
|
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||||
|
try:
|
||||||
|
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||||
|
except RuntimeError:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||||
|
)
|
||||||
|
|
||||||
|
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||||
|
if gptq_params.desc_act or gptq_params.groupsize == -1:
|
||||||
|
scales = self.get_tensor(f"{prefix}.scales")
|
||||||
|
else:
|
||||||
|
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||||
|
|
||||||
|
sharded_in_features = self.process_group.size() > 1
|
||||||
|
|
||||||
|
return repack_gptq_for_marlin(
|
||||||
|
qweight=qweight,
|
||||||
|
scales=scales,
|
||||||
|
g_idx=g_idx,
|
||||||
|
bits=gptq_params.bits,
|
||||||
|
desc_act=gptq_params.desc_act,
|
||||||
|
groupsize=gptq_params.groupsize,
|
||||||
|
sym=gptq_params.sym,
|
||||||
|
sharded_infeatures=sharded_in_features,
|
||||||
|
)
|
||||||
|
|
||||||
|
use_exllama = True
|
||||||
if gptq_params.bits != 4:
|
if gptq_params.bits != 4:
|
||||||
use_exllama = False
|
use_exllama = False
|
||||||
|
|
||||||
|
@ -672,10 +672,8 @@ class Weights:
|
||||||
from text_generation_server.layers.marlin import (
|
from text_generation_server.layers.marlin import (
|
||||||
GPTQMarlin24Weight,
|
GPTQMarlin24Weight,
|
||||||
MarlinWeight,
|
MarlinWeight,
|
||||||
repack_gptq_for_marlin,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
quant_method = getattr(self, "quant_method", "marlin")
|
|
||||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||||
if is_marlin_24:
|
if is_marlin_24:
|
||||||
try:
|
try:
|
||||||
|
@ -698,35 +696,6 @@ class Weights:
|
||||||
weight = GPTQMarlin24Weight(
|
weight = GPTQMarlin24Weight(
|
||||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||||
)
|
)
|
||||||
elif quant_method == "gptq":
|
|
||||||
log_once(logger.info, "Converting GPTQ model to Marlin packing format.")
|
|
||||||
gptq_params = self._get_gptq_params()
|
|
||||||
|
|
||||||
try:
|
|
||||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
|
||||||
except RuntimeError:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
|
||||||
)
|
|
||||||
|
|
||||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
|
||||||
if gptq_params.desc_act or gptq_params.groupsize == -1:
|
|
||||||
scales = self.get_tensor(f"{prefix}.scales")
|
|
||||||
else:
|
|
||||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
|
||||||
|
|
||||||
sharded_in_features = self.process_group.size() > 1
|
|
||||||
|
|
||||||
weight = repack_gptq_for_marlin(
|
|
||||||
qweight=qweight,
|
|
||||||
scales=scales,
|
|
||||||
g_idx=g_idx,
|
|
||||||
bits=gptq_params.bits,
|
|
||||||
desc_act=gptq_params.desc_act,
|
|
||||||
groupsize=gptq_params.groupsize,
|
|
||||||
sym=gptq_params.sym,
|
|
||||||
sharded_infeatures=sharded_in_features,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
B = self.get_sharded(f"{prefix}.B", dim=0)
|
B = self.get_sharded(f"{prefix}.B", dim=0)
|
||||||
|
@ -743,18 +712,17 @@ class Weights:
|
||||||
else:
|
else:
|
||||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||||
weight = MarlinWeight(B=B, s=s)
|
weight = MarlinWeight(B=B, s=s)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
def _get_gptq_params(self) -> _GPTQParams:
|
def _get_gptq_params(self) -> GPTQParams:
|
||||||
try:
|
try:
|
||||||
bits = self.get_tensor("gptq_bits").item()
|
bits = self.get_tensor("gptq_bits").item()
|
||||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||||
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
||||||
desc_act = False
|
desc_act = False
|
||||||
sym = True
|
sym = False
|
||||||
quant_method = "gptq"
|
quant_method = "gptq"
|
||||||
except (SafetensorError, RuntimeError) as e:
|
except (SafetensorError, RuntimeError) as e:
|
||||||
try:
|
try:
|
||||||
|
@ -767,7 +735,7 @@ class Weights:
|
||||||
except Exception:
|
except Exception:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return _GPTQParams(
|
return GPTQParams(
|
||||||
bits=bits,
|
bits=bits,
|
||||||
checkpoint_format=checkpoint_format,
|
checkpoint_format=checkpoint_format,
|
||||||
desc_act=desc_act,
|
desc_act=desc_act,
|
||||||
|
|
Loading…
Reference in New Issue