Add support for wNa16 int 2:4 compressed-tensors checkpoints (#2758)
This change adds support for wNa16 int checkpoints with 2:4 sparsity using Marlin 2:4 kernels.
parent 2fda8845a7 · commit 46a5a7e73e
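For context, the checkpoints this change targets declare an int weight quantization with the `marlin-24` compression format in their `config.json`. A minimal sketch of the relevant fields, written as a Python dict with illustrative values that are not taken from this PR, might look like this:

```python
# Illustrative only: the loader below inspects format, weight type and num_bits;
# the concrete field values here are assumptions, not copied from a real checkpoint.
quantization_config = {
    "quant_method": "compressed-tensors",
    "format": "marlin-24",  # CompressionFormat.marlin_24.value
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {
                "type": "int",        # QuantizationType.INT
                "num_bits": 4,        # wNa16 with N in (4, 8)
                "strategy": "channel",
                "group_size": None,
                "symmetric": True,
            },
            "input_activations": None,
        }
    },
}
```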
@@ -0,0 +1,104 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
      {"id": 3923, "logprob": -7.5390625, "text": "What"},
      {"id": 374, "logprob": -0.86035156, "text": " is"},
      {"id": 5655, "logprob": -8.828125, "text": " deep"},
      {"id": 6975, "logprob": -1.4912109, "text": " learning"},
      {"id": 30, "logprob": -2.1152344, "text": "?"}
    ],
    "seed": null,
    "tokens": [
      {"id": 34564, "logprob": -1.765625, "special": false, "text": "Deep"},
      {"id": 6975, "logprob": -0.023864746, "special": false, "text": " learning"},
      {"id": 374, "logprob": -0.1060791, "special": false, "text": " is"},
      {"id": 264, "logprob": -0.1940918, "special": false, "text": " a"},
      {"id": 27084, "logprob": -0.79785156, "special": false, "text": " subset"},
      {"id": 315, "logprob": -0.008262634, "special": false, "text": " of"},
      {"id": 5780, "logprob": -0.046569824, "special": false, "text": " machine"},
      {"id": 6975, "logprob": -0.0023479462, "special": false, "text": " learning"},
      {"id": 430, "logprob": -0.7626953, "special": false, "text": " that"},
      {"id": 5829, "logprob": -1.0107422, "special": false, "text": " uses"}
    ],
    "top_tokens": null
  },
  "generated_text": "Deep learning is a subset of machine learning that uses"
}
@@ -0,0 +1,99 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
      {"id": 3923, "logprob": -7.5390625, "text": "What"},
      {"id": 374, "logprob": -0.86035156, "text": " is"},
      {"id": 5655, "logprob": -8.828125, "text": " deep"},
      {"id": 6975, "logprob": -1.4912109, "text": " learning"}
    ],
    "seed": 0,
    "tokens": [
      {"id": 5380, "logprob": 0.0, "special": false, "text": "?\n"},
      {"id": 34564, "logprob": 0.0, "special": false, "text": "Deep"},
      {"id": 6975, "logprob": 0.0, "special": false, "text": " learning"},
      {"id": 320, "logprob": -0.19580078, "special": false, "text": " ("},
      {"id": 16931, "logprob": -1.7783203, "special": false, "text": "DL"},
      {"id": 8, "logprob": 0.0, "special": false, "text": ")"},
      {"id": 374, "logprob": -1.4287109, "special": false, "text": " is"},
      {"id": 264, "logprob": 0.0, "special": false, "text": " a"},
      {"id": 27084, "logprob": 0.0, "special": false, "text": " subset"},
      {"id": 315, "logprob": 0.0, "special": false, "text": " of"}
    ],
    "top_tokens": null
  },
  "generated_text": "What is deep learning?\nDeep learning (DL) is a subset of"
}
@@ -0,0 +1,418 @@
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
        {"id": 3923, "logprob": -7.5390625, "text": "What"},
        {"id": 374, "logprob": -0.86035156, "text": " is"},
        {"id": 5655, "logprob": -8.828125, "text": " deep"},
        {"id": 6975, "logprob": -1.4912109, "text": " learning"},
        {"id": 30, "logprob": -2.1152344, "text": "?"}
      ],
      "seed": null,
      "tokens": [
        {"id": 34564, "logprob": -1.765625, "special": false, "text": "Deep"},
        {"id": 6975, "logprob": -0.024002075, "special": false, "text": " learning"},
        {"id": 374, "logprob": -0.10760498, "special": false, "text": " is"},
        {"id": 264, "logprob": -0.19580078, "special": false, "text": " a"},
        {"id": 27084, "logprob": -0.7993164, "special": false, "text": " subset"},
        {"id": 315, "logprob": -0.008300781, "special": false, "text": " of"},
        {"id": 5780, "logprob": -0.046295166, "special": false, "text": " machine"},
        {"id": 6975, "logprob": -0.002374649, "special": false, "text": " learning"},
        {"id": 430, "logprob": -0.7651367, "special": false, "text": " that"},
        {"id": 5829, "logprob": -1.0107422, "special": false, "text": " uses"}
      ],
      "top_tokens": null
    },
    "generated_text": "Deep learning is a subset of machine learning that uses"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
        {"id": 3923, "logprob": -7.5351562, "text": "What"},
        {"id": 374, "logprob": -0.85791016, "text": " is"},
        {"id": 5655, "logprob": -8.828125, "text": " deep"},
        {"id": 6975, "logprob": -1.4882812, "text": " learning"},
        {"id": 30, "logprob": -2.1210938, "text": "?"}
      ],
      "seed": null,
      "tokens": [
        {"id": 34564, "logprob": -1.7597656, "special": false, "text": "Deep"},
        {"id": 6975, "logprob": -0.024032593, "special": false, "text": " learning"},
        {"id": 374, "logprob": -0.10748291, "special": false, "text": " is"},
        {"id": 264, "logprob": -0.19592285, "special": false, "text": " a"},
        {"id": 27084, "logprob": -0.7988281, "special": false, "text": " subset"},
        {"id": 315, "logprob": -0.008354187, "special": false, "text": " of"},
        {"id": 5780, "logprob": -0.046569824, "special": false, "text": " machine"},
        {"id": 6975, "logprob": -0.0023517609, "special": false, "text": " learning"},
        {"id": 430, "logprob": -0.7661133, "special": false, "text": " that"},
        {"id": 5829, "logprob": -1.0107422, "special": false, "text": " uses"}
      ],
      "top_tokens": null
    },
    "generated_text": "Deep learning is a subset of machine learning that uses"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
        {"id": 3923, "logprob": -7.5351562, "text": "What"},
        {"id": 374, "logprob": -0.85791016, "text": " is"},
        {"id": 5655, "logprob": -8.828125, "text": " deep"},
        {"id": 6975, "logprob": -1.4882812, "text": " learning"},
        {"id": 30, "logprob": -2.1210938, "text": "?"}
      ],
      "seed": null,
      "tokens": [
        {"id": 34564, "logprob": -1.7597656, "special": false, "text": "Deep"},
        {"id": 6975, "logprob": -0.024032593, "special": false, "text": " learning"},
        {"id": 374, "logprob": -0.10748291, "special": false, "text": " is"},
        {"id": 264, "logprob": -0.19592285, "special": false, "text": " a"},
        {"id": 27084, "logprob": -0.7988281, "special": false, "text": " subset"},
        {"id": 315, "logprob": -0.008354187, "special": false, "text": " of"},
        {"id": 5780, "logprob": -0.046569824, "special": false, "text": " machine"},
        {"id": 6975, "logprob": -0.0023517609, "special": false, "text": " learning"},
        {"id": 430, "logprob": -0.7661133, "special": false, "text": " that"},
        {"id": 5829, "logprob": -1.0107422, "special": false, "text": " uses"}
      ],
      "top_tokens": null
    },
    "generated_text": "Deep learning is a subset of machine learning that uses"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 128000, "logprob": null, "text": "<|begin_of_text|>"},
        {"id": 3923, "logprob": -7.5351562, "text": "What"},
        {"id": 374, "logprob": -0.85791016, "text": " is"},
        {"id": 5655, "logprob": -8.828125, "text": " deep"},
        {"id": 6975, "logprob": -1.4882812, "text": " learning"},
        {"id": 30, "logprob": -2.1210938, "text": "?"}
      ],
      "seed": null,
      "tokens": [
        {"id": 34564, "logprob": -1.7597656, "special": false, "text": "Deep"},
        {"id": 6975, "logprob": -0.024032593, "special": false, "text": " learning"},
        {"id": 374, "logprob": -0.10748291, "special": false, "text": " is"},
        {"id": 264, "logprob": -0.19592285, "special": false, "text": " a"},
        {"id": 27084, "logprob": -0.7988281, "special": false, "text": " subset"},
        {"id": 315, "logprob": -0.008354187, "special": false, "text": " of"},
        {"id": 5780, "logprob": -0.046569824, "special": false, "text": " machine"},
        {"id": 6975, "logprob": -0.0023517609, "special": false, "text": " learning"},
        {"id": 430, "logprob": -0.7661133, "special": false, "text": " that"},
        {"id": 5829, "logprob": -1.0107422, "special": false, "text": " uses"}
      ],
      "top_tokens": null
    },
    "generated_text": "Deep learning is a subset of machine learning that uses"
  }
]
@@ -0,0 +1,90 @@
import pytest


@pytest.fixture(scope="module")
def compressed_tensors_wna16_int_24_handle(launcher):
    with launcher(
        "danieldk/Llama-3.1-8B-w4a16-int-24",
        num_shard=2,
        quantize="compressed-tensors",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def compressed_tensors_wna16_int_24(compressed_tensors_wna16_int_24_handle):
    await compressed_tensors_wna16_int_24_handle.health(300)
    return compressed_tensors_wna16_int_24_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_compressed_tensors_wna16_int_24(
    compressed_tensors_wna16_int_24, response_snapshot
):
    response = await compressed_tensors_wna16_int_24.generate(
        "What is deep learning?",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert (
        response.generated_text
        == "Deep learning is a subset of machine learning that uses"
    )
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_compressed_tensors_wna16_int_24_all_params(
    compressed_tensors_wna16_int_24, response_snapshot
):
    response = await compressed_tensors_wna16_int_24.generate(
        "What is deep learning",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert (
        response.generated_text
        == "What is deep learning?\nDeep learning (DL) is a subset of"
    )
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_compressed_tensors_wna16_int_24_load(
    compressed_tensors_wna16_int_24, generate_load, response_snapshot
):
    responses = await generate_load(
        compressed_tensors_wna16_int_24,
        "What is deep learning?",
        max_new_tokens=10,
        n=4,
    )

    assert (
        responses[0].generated_text
        == "Deep learning is a subset of machine learning that uses"
    )
    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
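Outside of the test harness, a server started against the same checkpoint can be queried with the `text_generation` Python client, which is roughly what the integration fixtures wrap. A minimal sketch, assuming a locally running server (the model id mirrors the test above; the port and launch flags are assumptions, not mandated by this PR):

```python
# Sketch only: assumes a text-generation-inference server is already running, e.g.
#   text-generation-launcher --model-id danieldk/Llama-3.1-8B-w4a16-int-24 \
#       --num-shard 2 --quantize compressed-tensors
import asyncio

from text_generation import AsyncClient


async def main() -> None:
    client = AsyncClient("http://127.0.0.1:8080")  # port is an assumption
    response = await client.generate(
        "What is deep learning?",
        max_new_tokens=10,
        decoder_input_details=True,
    )
    print(response.generated_text)
    print(response.details.generated_tokens)


asyncio.run(main())
```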
@@ -13,7 +13,10 @@ from torch import nn
 from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
 from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader
-from text_generation_server.layers.compressed_tensors.wna16_int import WNA16Loader
+from text_generation_server.layers.compressed_tensors.wna16_int_24 import (
+    WNA16Int24Loader,
+)
+from text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,

@@ -151,7 +154,14 @@ class CompressedTensorsLoader(WeightsLoader):
                 and weights.num_bits in (4, 8)
             ):
                 # INT W4A16 or W8A16 (GPTQ/AWQ-like).
-                return WNA16Loader(weights)
+                return WNA16IntLoader(weights)
+            elif (
+                format == CompressionFormat.marlin_24.value
+                and weights is not None
+                and weights.type == QuantizationType.INT
+                and weights.num_bits in (4, 8)
+            ):
+                return WNA16Int24Loader(weights)
             elif (
                 format
                 in {
@@ -9,7 +9,7 @@ from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weights, WeightsLoader


-class WNA16Loader(WeightsLoader):
+class WNA16IntLoader(WeightsLoader):
     """
     Loader for W4A16/W8A16 INT compressed-tensors parameters.
     """

@@ -22,7 +22,7 @@ class WNA16Loader(WeightsLoader):
         )

     def __str__(self) -> str:
-        quantization_type = f"W{self.weights.num_bits}8A16"
+        quantization_type = f"W{self.weights.num_bits}A16"

         return f"{self.__class__.__name__} ({quantization_type})"

@@ -0,0 +1,101 @@
from typing import List, Union

import torch

from compressed_tensors.quantization import QuantizationArgs, QuantizationType

from text_generation_server.layers.marlin.marlin import GPTQMarlin24Weight
from text_generation_server.utils.weights import Weights, WeightsLoader


class WNA16Int24Loader(WeightsLoader):
    """
    Loader for W4A16/W8A16 INT 2:4 sparsity compressed-tensors checkpoints.
    """

    def __init__(self, weight_args: QuantizationArgs):
        super().__init__()

        if weight_args.type != QuantizationType.INT:
            raise ValueError(
                f"{type(self).__name__} only supports wNa16 int checkpoints"
            )

        if weight_args.strategy == "group" and weight_args.group_size is None:
            raise ValueError("`group_size` must be set when `strategy` is `group`")

        self.bits = weight_args.num_bits
        self.group_size = weight_args.group_size

    def __str__(self) -> str:
        quantization_type = f"W{self.bits}A16 2:4 sparsity"

        return f"{self.__class__.__name__} ({quantization_type})"

    def get_weights(self, weights: Weights, prefix: str):
        """
        Get weights at the given prefix and apply without tensor parallelism.
        """
        weight_packed = weights.get_tensor(f"{prefix}.weight_packed")
        meta = weights.get_tensor(f"{prefix}.meta")
        scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
        return GPTQMarlin24Weight(
            weight_packed=weight_packed,
            meta=meta,
            scale_packed=scale_packed,
            bits=self.bits,
        )

    def get_weights_col_packed(
        self,
        weights: Weights,
        prefix: str,
        block_sizes: Union[int, List[int]],
    ):
        weight_packed = weights.get_packed_sharded(
            f"{prefix}.weight_packed", dim=1, block_sizes=block_sizes
        )
        meta = weights.get_packed_sharded(
            f"{prefix}.meta", dim=1, block_sizes=block_sizes
        )
        scale_packed = weights.get_packed_sharded(
            f"{prefix}.scale_packed", dim=1, block_sizes=block_sizes
        )
        return GPTQMarlin24Weight(
            weight_packed=weight_packed,
            meta=meta,
            scale_packed=scale_packed,
            bits=self.bits,
        )

    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
        weight_packed = torch.cat(
            [weights.get_sharded(f"{p}.weight_packed", dim=1) for p in prefixes], dim=1
        )
        meta = torch.cat(
            [weights.get_sharded(f"{p}.meta", dim=1) for p in prefixes], dim=1
        )
        scale_packed = torch.cat(
            [weights.get_sharded(f"{p}.scale_packed", dim=1) for p in prefixes], dim=1
        )
        return GPTQMarlin24Weight(
            weight_packed=weight_packed,
            meta=meta,
            scale_packed=scale_packed,
            bits=self.bits,
        )

    def get_weights_row(self, weights: Weights, prefix: str):
        weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=0)
        meta = weights.get_sharded(f"{prefix}.meta", dim=0)
        if self.group_size is None:
            scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
        else:
            scale_packed = weights.get_sharded(f"{prefix}.scale_packed", dim=0)

        return GPTQMarlin24Weight(
            weight_packed=weight_packed,
            meta=meta,
            scale_packed=scale_packed,
            bits=self.bits,
        )
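The branch in `get_weights_row` reflects how the scales relate to the sharded dimension: grouped scales follow the input-feature (row) split, while channel-wise scales must be replicated on every rank. A small stand-alone illustration with plain tensors (sizes and world size are made-up assumptions, not the TGI `Weights` API):

```python
# Illustrative shapes only; the concrete sizes are assumptions, not from this PR.
import torch

in_features, out_features, group_size, world_size = 4096, 4096, 128, 2

# Grouped scales: one scale row per group of input features, so splitting the
# input (row) dimension across ranks also splits the scales -> get_sharded(dim=0).
grouped_scales = torch.empty(in_features // group_size, out_features)
per_rank_groups = grouped_scales.shape[0] // world_size  # 16 groups per rank

# Channel-wise scales (group_size is None): a single scale row shared by all
# input features, so every rank needs the full tensor -> get_tensor().
channel_scales = torch.empty(1, out_features)

print(per_rank_groups, tuple(channel_scales.shape))
```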
@@ -34,7 +34,9 @@ class MarlinWeightsLoader(WeightsLoader):

             B_meta = weights.get_tensor(f"{prefix}.B_meta")
             s = weights.get_tensor(f"{prefix}.s")
-            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
+            weight = GPTQMarlin24Weight(
+                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
+            )
         else:
             try:
                 B = weights.get_tensor(f"{prefix}.B")

@@ -65,7 +67,9 @@ class MarlinWeightsLoader(WeightsLoader):
                 f"{prefix}.s", dim=1, block_sizes=block_sizes
             )

-            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
+            weight = GPTQMarlin24Weight(
+                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
+            )
         else:
             B = weights.get_packed_sharded(
                 f"{prefix}.B", dim=1, block_sizes=block_sizes

@@ -96,7 +100,9 @@ class MarlinWeightsLoader(WeightsLoader):
                 [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
             )

-            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
+            weight = GPTQMarlin24Weight(
+                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
+            )
         else:
             try:
                 B = torch.cat(

@@ -132,7 +138,9 @@ class MarlinWeightsLoader(WeightsLoader):
             else:
                 s = weights.get_sharded(f"{prefix}.s", dim=0)

-            weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
+            weight = GPTQMarlin24Weight(
+                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
+            )
         else:
             try:
                 B = weights.get_sharded(f"{prefix}.B", dim=0)

@@ -247,15 +255,15 @@ class GPTQMarlin24Weight:
         bits: quantized weight size.
     """

-    B: torch.Tensor
-    B_meta: torch.Tensor
-    s: torch.Tensor
+    weight_packed: torch.Tensor
+    meta: torch.Tensor
+    scale_packed: torch.Tensor
     bits: int

     def __post_init__(self):
-        assert self.B.dtype == torch.int32
-        assert self.B_meta.dtype == torch.int16
-        assert self.s.dtype == torch.float16
+        assert self.weight_packed.dtype == torch.int32
+        assert self.meta.dtype == torch.int16
+        assert self.scale_packed.dtype == torch.float16

     def get_linear(self, bias: torch.Tensor):
         return GPTQMarlin24Linear(

@@ -279,9 +287,13 @@ class GPTQMarlin24Linear(nn.Module):
                 f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
             )

-        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
-        out_features = weight.s.shape[1]
-        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
+        in_features = weight.weight_packed.shape[0] * MARLIN_TILE_SIZE * 2
+        out_features = weight.scale_packed.shape[1]
+        groupsize = (
+            -1
+            if weight.scale_packed.shape[0] == 1
+            else in_features // weight.scale_packed.shape[0]
+        )

         if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
             supported_sizes = ", ".join(

@@ -309,9 +321,9 @@ class GPTQMarlin24Linear(nn.Module):
                 f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
             )

-        self.B = weight.B
-        self.B_meta = weight.B_meta
-        self.s = weight.s
+        self.weight_packed = weight.weight_packed
+        self.meta = weight.meta
+        self.scale_packed = weight.scale_packed
         if bias is not None:
             self.bias = bias
         else:

@@ -320,7 +332,7 @@ class GPTQMarlin24Linear(nn.Module):
         self.workspace = torch.zeros(
             (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
             dtype=torch.int,
-            device=weight.B.device,
+            device=weight.weight_packed.device,
         )

     def forward(self, A: torch.Tensor) -> torch.Tensor:

@@ -328,17 +340,17 @@ class GPTQMarlin24Linear(nn.Module):

         C = marlin_kernels.gptq_marlin_24_gemm(
             A.view(-1, A.shape[-1]),
-            self.B,
-            self.B_meta,
-            self.s,
+            self.weight_packed,
+            self.meta,
+            self.scale_packed,
             self.workspace,
             self.bits,
             A.shape[0],
-            self.s.shape[1],
+            self.scale_packed.shape[1],
             A.shape[1],
         )

-        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
+        C = C.reshape(A.shape[:-1] + (self.scale_packed.shape[1],))

         if self.bias is not None:
             C += self.bias
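The renamed fields make the shape recovery in GPTQMarlin24Linear easier to follow: each row of weight_packed covers one Marlin tile of input features, and 2:4 sparsity stores only half of them, hence the extra factor of 2. A small stand-alone sketch of that bookkeeping (the concrete sizes are illustrative assumptions, not values from this PR):

```python
# Illustrative only: mirrors the shape recovery in GPTQMarlin24Linear above.
MARLIN_TILE_SIZE = 16  # Marlin packs weights in 16x16 tiles


def marlin_24_shapes(weight_packed_rows: int, scale_rows: int, scale_cols: int):
    # 2:4 sparsity stores 2 of every 4 weights, hence the extra factor of 2.
    in_features = weight_packed_rows * MARLIN_TILE_SIZE * 2
    out_features = scale_cols
    # A single scale row means channel-wise quantization (groupsize == -1).
    groupsize = -1 if scale_rows == 1 else in_features // scale_rows
    return in_features, out_features, groupsize


# e.g. a hypothetical 4096x4096 layer quantized with group size 128:
print(marlin_24_shapes(weight_packed_rows=128, scale_rows=32, scale_cols=4096))
# -> (4096, 4096, 128)
```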