Add support for wNa16 int 2:4 compressed-tensors checkpoints (#2758)
This change adds support for wNa16 int checkpoints with 2:4 sparsity using Marlin 2:4 kernels.
This commit is contained in:
parent
2fda8845a7
commit
46a5a7e73e
|
@ -0,0 +1,104 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3923,
|
||||||
|
"logprob": -7.5390625,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.86035156,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -8.828125,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -1.4912109,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30,
|
||||||
|
"logprob": -2.1152344,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 34564,
|
||||||
|
"logprob": -1.765625,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.023864746,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.1060791,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.1940918,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 27084,
|
||||||
|
"logprob": -0.79785156,
|
||||||
|
"special": false,
|
||||||
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 315,
|
||||||
|
"logprob": -0.008262634,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5780,
|
||||||
|
"logprob": -0.046569824,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.0023479462,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 430,
|
||||||
|
"logprob": -0.7626953,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5829,
|
||||||
|
"logprob": -1.0107422,
|
||||||
|
"special": false,
|
||||||
|
"text": " uses"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Deep learning is a subset of machine learning that uses"
|
||||||
|
}
|
|
@ -0,0 +1,99 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3923,
|
||||||
|
"logprob": -7.5390625,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.86035156,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -8.828125,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -1.4912109,
|
||||||
|
"text": " learning"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 5380,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "?\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 34564,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 320,
|
||||||
|
"logprob": -0.19580078,
|
||||||
|
"special": false,
|
||||||
|
"text": " ("
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 16931,
|
||||||
|
"logprob": -1.7783203,
|
||||||
|
"special": false,
|
||||||
|
"text": "DL"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": ")"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -1.4287109,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 27084,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 315,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "What is deep learning?\nDeep learning (DL) is a subset of"
|
||||||
|
}
|
|
@ -0,0 +1,418 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3923,
|
||||||
|
"logprob": -7.5390625,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.86035156,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -8.828125,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -1.4912109,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30,
|
||||||
|
"logprob": -2.1152344,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 34564,
|
||||||
|
"logprob": -1.765625,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.024002075,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.10760498,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.19580078,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 27084,
|
||||||
|
"logprob": -0.7993164,
|
||||||
|
"special": false,
|
||||||
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 315,
|
||||||
|
"logprob": -0.008300781,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5780,
|
||||||
|
"logprob": -0.046295166,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.002374649,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 430,
|
||||||
|
"logprob": -0.7651367,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5829,
|
||||||
|
"logprob": -1.0107422,
|
||||||
|
"special": false,
|
||||||
|
"text": " uses"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Deep learning is a subset of machine learning that uses"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3923,
|
||||||
|
"logprob": -7.5351562,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.85791016,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -8.828125,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -1.4882812,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30,
|
||||||
|
"logprob": -2.1210938,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 34564,
|
||||||
|
"logprob": -1.7597656,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.024032593,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.10748291,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.19592285,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 27084,
|
||||||
|
"logprob": -0.7988281,
|
||||||
|
"special": false,
|
||||||
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 315,
|
||||||
|
"logprob": -0.008354187,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5780,
|
||||||
|
"logprob": -0.046569824,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.0023517609,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 430,
|
||||||
|
"logprob": -0.7661133,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5829,
|
||||||
|
"logprob": -1.0107422,
|
||||||
|
"special": false,
|
||||||
|
"text": " uses"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Deep learning is a subset of machine learning that uses"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3923,
|
||||||
|
"logprob": -7.5351562,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.85791016,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -8.828125,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -1.4882812,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30,
|
||||||
|
"logprob": -2.1210938,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 34564,
|
||||||
|
"logprob": -1.7597656,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.024032593,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.10748291,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.19592285,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 27084,
|
||||||
|
"logprob": -0.7988281,
|
||||||
|
"special": false,
|
||||||
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 315,
|
||||||
|
"logprob": -0.008354187,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5780,
|
||||||
|
"logprob": -0.046569824,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.0023517609,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 430,
|
||||||
|
"logprob": -0.7661133,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5829,
|
||||||
|
"logprob": -1.0107422,
|
||||||
|
"special": false,
|
||||||
|
"text": " uses"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Deep learning is a subset of machine learning that uses"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 128000,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<|begin_of_text|>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3923,
|
||||||
|
"logprob": -7.5351562,
|
||||||
|
"text": "What"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.85791016,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5655,
|
||||||
|
"logprob": -8.828125,
|
||||||
|
"text": " deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -1.4882812,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 30,
|
||||||
|
"logprob": -2.1210938,
|
||||||
|
"text": "?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 34564,
|
||||||
|
"logprob": -1.7597656,
|
||||||
|
"special": false,
|
||||||
|
"text": "Deep"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.024032593,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 374,
|
||||||
|
"logprob": -0.10748291,
|
||||||
|
"special": false,
|
||||||
|
"text": " is"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 264,
|
||||||
|
"logprob": -0.19592285,
|
||||||
|
"special": false,
|
||||||
|
"text": " a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 27084,
|
||||||
|
"logprob": -0.7988281,
|
||||||
|
"special": false,
|
||||||
|
"text": " subset"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 315,
|
||||||
|
"logprob": -0.008354187,
|
||||||
|
"special": false,
|
||||||
|
"text": " of"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5780,
|
||||||
|
"logprob": -0.046569824,
|
||||||
|
"special": false,
|
||||||
|
"text": " machine"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6975,
|
||||||
|
"logprob": -0.0023517609,
|
||||||
|
"special": false,
|
||||||
|
"text": " learning"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 430,
|
||||||
|
"logprob": -0.7661133,
|
||||||
|
"special": false,
|
||||||
|
"text": " that"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5829,
|
||||||
|
"logprob": -1.0107422,
|
||||||
|
"special": false,
|
||||||
|
"text": " uses"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Deep learning is a subset of machine learning that uses"
|
||||||
|
}
|
||||||
|
]
|
|
@ -0,0 +1,90 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def compressed_tensors_wna16_int_24_handle(launcher):
|
||||||
|
with launcher(
|
||||||
|
"danieldk/Llama-3.1-8B-w4a16-int-24",
|
||||||
|
num_shard=2,
|
||||||
|
quantize="compressed-tensors",
|
||||||
|
) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def compressed_tensors_wna16_int_24(compressed_tensors_wna16_int_24_handle):
|
||||||
|
await compressed_tensors_wna16_int_24_handle.health(300)
|
||||||
|
return compressed_tensors_wna16_int_24_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_compressed_tensors_wna16_int_24(
|
||||||
|
compressed_tensors_wna16_int_24, response_snapshot
|
||||||
|
):
|
||||||
|
response = await compressed_tensors_wna16_int_24.generate(
|
||||||
|
"What is deep learning?",
|
||||||
|
max_new_tokens=10,
|
||||||
|
decoder_input_details=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
response.generated_text
|
||||||
|
== "Deep learning is a subset of machine learning that uses"
|
||||||
|
)
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_compressed_tensors_wna16_int_24_all_params(
|
||||||
|
compressed_tensors_wna16_int_24, response_snapshot
|
||||||
|
):
|
||||||
|
response = await compressed_tensors_wna16_int_24.generate(
|
||||||
|
"What is deep learning",
|
||||||
|
max_new_tokens=10,
|
||||||
|
repetition_penalty=1.2,
|
||||||
|
return_full_text=True,
|
||||||
|
stop_sequences=["test"],
|
||||||
|
temperature=0.5,
|
||||||
|
top_p=0.9,
|
||||||
|
top_k=10,
|
||||||
|
truncate=5,
|
||||||
|
typical_p=0.9,
|
||||||
|
watermark=True,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert (
|
||||||
|
response.generated_text
|
||||||
|
== "What is deep learning?\nDeep learning (DL) is a subset of"
|
||||||
|
)
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_compressed_tensors_wna16_int_24_load(
|
||||||
|
compressed_tensors_wna16_int_24, generate_load, response_snapshot
|
||||||
|
):
|
||||||
|
responses = await generate_load(
|
||||||
|
compressed_tensors_wna16_int_24,
|
||||||
|
"What is deep learning?",
|
||||||
|
max_new_tokens=10,
|
||||||
|
n=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
responses[0].generated_text
|
||||||
|
== "Deep learning is a subset of machine learning that uses"
|
||||||
|
)
|
||||||
|
assert len(responses) == 4
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
|
@ -13,7 +13,10 @@ from torch import nn
|
||||||
|
|
||||||
from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
|
from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
|
||||||
from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader
|
from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader
|
||||||
from text_generation_server.layers.compressed_tensors.wna16_int import WNA16Loader
|
from text_generation_server.layers.compressed_tensors.wna16_int_24 import (
|
||||||
|
WNA16Int24Loader,
|
||||||
|
)
|
||||||
|
from text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader
|
||||||
from text_generation_server.utils.log import log_once
|
from text_generation_server.utils.log import log_once
|
||||||
from text_generation_server.utils.weights import (
|
from text_generation_server.utils.weights import (
|
||||||
DefaultWeightsLoader,
|
DefaultWeightsLoader,
|
||||||
|
@ -151,7 +154,14 @@ class CompressedTensorsLoader(WeightsLoader):
|
||||||
and weights.num_bits in (4, 8)
|
and weights.num_bits in (4, 8)
|
||||||
):
|
):
|
||||||
# INT W4A16 or W8A16 (GPTQ/AWQ-like).
|
# INT W4A16 or W8A16 (GPTQ/AWQ-like).
|
||||||
return WNA16Loader(weights)
|
return WNA16IntLoader(weights)
|
||||||
|
elif (
|
||||||
|
format == CompressionFormat.marlin_24.value
|
||||||
|
and weights is not None
|
||||||
|
and weights.type == QuantizationType.INT
|
||||||
|
and weights.num_bits in (4, 8)
|
||||||
|
):
|
||||||
|
return WNA16Int24Loader(weights)
|
||||||
elif (
|
elif (
|
||||||
format
|
format
|
||||||
in {
|
in {
|
||||||
|
|
|
@ -9,7 +9,7 @@ from text_generation_server.utils.log import log_once
|
||||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||||
|
|
||||||
|
|
||||||
class WNA16Loader(WeightsLoader):
|
class WNA16IntLoader(WeightsLoader):
|
||||||
"""
|
"""
|
||||||
Loader for W4A16/W8A16 INT compressed-tensors parameters.
|
Loader for W4A16/W8A16 INT compressed-tensors parameters.
|
||||||
"""
|
"""
|
||||||
|
@ -22,7 +22,7 @@ class WNA16Loader(WeightsLoader):
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
quantization_type = f"W{self.weights.num_bits}8A16"
|
quantization_type = f"W{self.weights.num_bits}A16"
|
||||||
|
|
||||||
return f"{self.__class__.__name__} ({quantization_type})"
|
return f"{self.__class__.__name__} ({quantization_type})"
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
|
||||||
|
from text_generation_server.layers.marlin.marlin import GPTQMarlin24Weight
|
||||||
|
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||||
|
|
||||||
|
|
||||||
|
class WNA16Int24Loader(WeightsLoader):
|
||||||
|
"""
|
||||||
|
Loader for W4A16/W8A16 INT 2:4 sparsity compressed-tensors checkpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, weight_args: QuantizationArgs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if weight_args.type != QuantizationType.INT:
|
||||||
|
raise ValueError(
|
||||||
|
f"{type(self).__name__} only supports wNa8 int checkpoints"
|
||||||
|
)
|
||||||
|
|
||||||
|
if weight_args.strategy == "group" and weight_args.group_size is None:
|
||||||
|
raise ValueError("`group_size` must be set when `actorder` is `group`")
|
||||||
|
|
||||||
|
self.bits = weight_args.num_bits
|
||||||
|
self.group_size = weight_args.group_size
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
quantization_type = f"W{self.bits}A16 2:4 sparsity"
|
||||||
|
|
||||||
|
return f"{self.__class__.__name__} ({quantization_type})"
|
||||||
|
|
||||||
|
def get_weights(self, weights: Weights, prefix: str):
|
||||||
|
"""
|
||||||
|
Get weights at the given prefix and apply without tensor paralllism.
|
||||||
|
"""
|
||||||
|
weight_packed = weights.get_tensor(f"{prefix}.weight_packed")
|
||||||
|
meta = weights.get_tensor(f"{prefix}.meta")
|
||||||
|
scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
|
||||||
|
return GPTQMarlin24Weight(
|
||||||
|
weight_packed=weight_packed,
|
||||||
|
meta=meta,
|
||||||
|
scale_packed=scale_packed,
|
||||||
|
bits=self.bits,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_col_packed(
|
||||||
|
self,
|
||||||
|
weights: Weights,
|
||||||
|
prefix: str,
|
||||||
|
block_sizes: Union[int, List[int]],
|
||||||
|
):
|
||||||
|
weight_packed = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.weight_packed", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
meta = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.meta", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
scale_packed = weights.get_packed_sharded(
|
||||||
|
f"{prefix}.scale_packed", dim=1, block_sizes=block_sizes
|
||||||
|
)
|
||||||
|
return GPTQMarlin24Weight(
|
||||||
|
weight_packed=weight_packed,
|
||||||
|
meta=meta,
|
||||||
|
scale_packed=scale_packed,
|
||||||
|
bits=self.bits,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||||
|
weight_packed = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.weight_packed", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
meta = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.meta", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
scale_packed = torch.cat(
|
||||||
|
[weights.get_sharded(f"{p}.scale_packed", dim=1) for p in prefixes], dim=1
|
||||||
|
)
|
||||||
|
return GPTQMarlin24Weight(
|
||||||
|
weight_packed=weight_packed,
|
||||||
|
meta=meta,
|
||||||
|
scale_packed=scale_packed,
|
||||||
|
bits=self.bits,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weights_row(self, weights: Weights, prefix: str):
|
||||||
|
weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=0)
|
||||||
|
meta = weights.get_sharded(f"{prefix}.meta", dim=0)
|
||||||
|
if self.group_size is None:
|
||||||
|
scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
|
||||||
|
else:
|
||||||
|
scale_packed = weights.get_sharded(f"{prefix}.scale_packed", dim=0)
|
||||||
|
|
||||||
|
return GPTQMarlin24Weight(
|
||||||
|
weight_packed=weight_packed,
|
||||||
|
meta=meta,
|
||||||
|
scale_packed=scale_packed,
|
||||||
|
bits=self.bits,
|
||||||
|
)
|
|
@ -34,7 +34,9 @@ class MarlinWeightsLoader(WeightsLoader):
|
||||||
|
|
||||||
B_meta = weights.get_tensor(f"{prefix}.B_meta")
|
B_meta = weights.get_tensor(f"{prefix}.B_meta")
|
||||||
s = weights.get_tensor(f"{prefix}.s")
|
s = weights.get_tensor(f"{prefix}.s")
|
||||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
weight = GPTQMarlin24Weight(
|
||||||
|
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
B = weights.get_tensor(f"{prefix}.B")
|
B = weights.get_tensor(f"{prefix}.B")
|
||||||
|
@ -65,7 +67,9 @@ class MarlinWeightsLoader(WeightsLoader):
|
||||||
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||||
)
|
)
|
||||||
|
|
||||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
weight = GPTQMarlin24Weight(
|
||||||
|
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
B = weights.get_packed_sharded(
|
B = weights.get_packed_sharded(
|
||||||
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||||
|
@ -96,7 +100,9 @@ class MarlinWeightsLoader(WeightsLoader):
|
||||||
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
weight = GPTQMarlin24Weight(
|
||||||
|
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
B = torch.cat(
|
B = torch.cat(
|
||||||
|
@ -132,7 +138,9 @@ class MarlinWeightsLoader(WeightsLoader):
|
||||||
else:
|
else:
|
||||||
s = weights.get_sharded(f"{prefix}.s", dim=0)
|
s = weights.get_sharded(f"{prefix}.s", dim=0)
|
||||||
|
|
||||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
weight = GPTQMarlin24Weight(
|
||||||
|
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
B = weights.get_sharded(f"{prefix}.B", dim=0)
|
B = weights.get_sharded(f"{prefix}.B", dim=0)
|
||||||
|
@ -247,15 +255,15 @@ class GPTQMarlin24Weight:
|
||||||
bits: quantized weight size.
|
bits: quantized weight size.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
B: torch.Tensor
|
weight_packed: torch.Tensor
|
||||||
B_meta: torch.Tensor
|
meta: torch.Tensor
|
||||||
s: torch.Tensor
|
scale_packed: torch.Tensor
|
||||||
bits: int
|
bits: int
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert self.B.dtype == torch.int32
|
assert self.weight_packed.dtype == torch.int32
|
||||||
assert self.B_meta.dtype == torch.int16
|
assert self.meta.dtype == torch.int16
|
||||||
assert self.s.dtype == torch.float16
|
assert self.scale_packed.dtype == torch.float16
|
||||||
|
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
return GPTQMarlin24Linear(
|
return GPTQMarlin24Linear(
|
||||||
|
@ -279,9 +287,13 @@ class GPTQMarlin24Linear(nn.Module):
|
||||||
f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
|
f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
|
||||||
)
|
)
|
||||||
|
|
||||||
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
|
in_features = weight.weight_packed.shape[0] * MARLIN_TILE_SIZE * 2
|
||||||
out_features = weight.s.shape[1]
|
out_features = weight.scale_packed.shape[1]
|
||||||
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
|
groupsize = (
|
||||||
|
-1
|
||||||
|
if weight.scale_packed.shape[0] == 1
|
||||||
|
else in_features // weight.scale_packed.shape[0]
|
||||||
|
)
|
||||||
|
|
||||||
if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
|
if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
|
||||||
supported_sizes = ", ".join(
|
supported_sizes = ", ".join(
|
||||||
|
@ -309,9 +321,9 @@ class GPTQMarlin24Linear(nn.Module):
|
||||||
f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
|
f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.B = weight.B
|
self.weight_packed = weight.weight_packed
|
||||||
self.B_meta = weight.B_meta
|
self.meta = weight.meta
|
||||||
self.s = weight.s
|
self.scale_packed = weight.scale_packed
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
self.bias = bias
|
self.bias = bias
|
||||||
else:
|
else:
|
||||||
|
@ -320,7 +332,7 @@ class GPTQMarlin24Linear(nn.Module):
|
||||||
self.workspace = torch.zeros(
|
self.workspace = torch.zeros(
|
||||||
(out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
|
(out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||||
dtype=torch.int,
|
dtype=torch.int,
|
||||||
device=weight.B.device,
|
device=weight.weight_packed.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||||
|
@ -328,17 +340,17 @@ class GPTQMarlin24Linear(nn.Module):
|
||||||
|
|
||||||
C = marlin_kernels.gptq_marlin_24_gemm(
|
C = marlin_kernels.gptq_marlin_24_gemm(
|
||||||
A.view(-1, A.shape[-1]),
|
A.view(-1, A.shape[-1]),
|
||||||
self.B,
|
self.weight_packed,
|
||||||
self.B_meta,
|
self.meta,
|
||||||
self.s,
|
self.scale_packed,
|
||||||
self.workspace,
|
self.workspace,
|
||||||
self.bits,
|
self.bits,
|
||||||
A.shape[0],
|
A.shape[0],
|
||||||
self.s.shape[1],
|
self.scale_packed.shape[1],
|
||||||
A.shape[1],
|
A.shape[1],
|
||||||
)
|
)
|
||||||
|
|
||||||
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
|
C = C.reshape(A.shape[:-1] + (self.scale_packed.shape[1],))
|
||||||
|
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
C += self.bias
|
C += self.bias
|
||||||
|
|
Loading…
Reference in New Issue