Add support for GPTQ-quantized MoE models using MoE Marlin (#2557)
This change adds support for MoE models that use GPTQ quantization. Currently, only models with the following properties are supported: (1) no `desc_act` combined with tensor parallelism, unless `group_size=-1`; (2) no asymmetric quantization; (3) no AWQ.
This commit is contained in:
parent
f9e561eced
commit
90a1d04a2f
15
flake.lock
15
flake.lock
|
@ -978,16 +978,17 @@
|
||||||
"nixpkgs": "nixpkgs_6"
|
"nixpkgs": "nixpkgs_6"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1727353315,
|
"lastModified": 1727681169,
|
||||||
"narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=",
|
"narHash": "sha256-ssoGLmRoyQ+8d5utr5fwLox+/eQ789iVtUj1xrukIC0=",
|
||||||
"owner": "huggingface",
|
"owner": "danieldk",
|
||||||
"repo": "text-generation-inference-nix",
|
"repo": "tgi-nix",
|
||||||
"rev": "1d42c4125ebafb87707118168995675cc5050b9d",
|
"rev": "88ba4cfe378d8fb08222f640ff2b62ac0ee6569d",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
"owner": "huggingface",
|
"owner": "danieldk",
|
||||||
"repo": "text-generation-inference-nix",
|
"ref": "moe-kernels-0.4.0",
|
||||||
|
"repo": "tgi-nix",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||||
};
|
};
|
||||||
nix-filter.url = "github:numtide/nix-filter";
|
nix-filter.url = "github:numtide/nix-filter";
|
||||||
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
|
tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.4.0";
|
||||||
nixpkgs.follows = "tgi-nix/nixpkgs";
|
nixpkgs.follows = "tgi-nix/nixpkgs";
|
||||||
flake-utils.url = "github:numtide/flake-utils";
|
flake-utils.url = "github:numtide/flake-utils";
|
||||||
rust-overlay = {
|
rust-overlay = {
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -11.0078125,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -13.59375,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -1.7089844,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.68847656,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28771,
|
||||||
|
"logprob": -1.9394531,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -2.8808594,
|
||||||
|
"special": false,
|
||||||
|
"text": " Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -0.37280273,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.26098633,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.0017137527,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1064,
|
||||||
|
"logprob": -2.2695312,
|
||||||
|
"special": false,
|
||||||
|
"text": "##"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -1.9238281,
|
||||||
|
"special": false,
|
||||||
|
"text": " Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -0.48828125,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\n# Test request\n\n## Test request"
|
||||||
|
}
|
|
@ -0,0 +1,89 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -11.0078125,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -13.59375,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": 0,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.34838867,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13940,
|
||||||
|
"logprob": -0.38916016,
|
||||||
|
"special": false,
|
||||||
|
"text": "``"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28832,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "`"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3371,
|
||||||
|
"logprob": -1.2529297,
|
||||||
|
"special": false,
|
||||||
|
"text": "json"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28751,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "{"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2287,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " "
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 345,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": " \""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3134,
|
||||||
|
"logprob": -0.640625,
|
||||||
|
"special": false,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "Test request\n```json\n{\n \"request"
|
||||||
|
}
|
|
@ -0,0 +1,358 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -11.0078125,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -13.59375,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -1.7089844,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.68847656,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28771,
|
||||||
|
"logprob": -1.9394531,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -2.8828125,
|
||||||
|
"special": false,
|
||||||
|
"text": " Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -0.37329102,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.2602539,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.0017185211,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1064,
|
||||||
|
"logprob": -2.2753906,
|
||||||
|
"special": false,
|
||||||
|
"text": "##"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -1.9316406,
|
||||||
|
"special": false,
|
||||||
|
"text": " Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -0.48217773,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\n# Test request\n\n## Test request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -11.0078125,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -13.59375,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -1.7089844,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.68847656,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28771,
|
||||||
|
"logprob": -1.9394531,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -2.8828125,
|
||||||
|
"special": false,
|
||||||
|
"text": " Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -0.37329102,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.2602539,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.0017185211,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1064,
|
||||||
|
"logprob": -2.2753906,
|
||||||
|
"special": false,
|
||||||
|
"text": "##"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -1.9316406,
|
||||||
|
"special": false,
|
||||||
|
"text": " Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -0.48217773,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\n# Test request\n\n## Test request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -11.0078125,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -13.59375,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -1.7089844,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.68847656,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28771,
|
||||||
|
"logprob": -1.9394531,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -2.8828125,
|
||||||
|
"special": false,
|
||||||
|
"text": " Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -0.37329102,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.2602539,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.0017185211,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1064,
|
||||||
|
"logprob": -2.2753906,
|
||||||
|
"special": false,
|
||||||
|
"text": "##"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -1.9316406,
|
||||||
|
"special": false,
|
||||||
|
"text": " Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -0.48217773,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\n# Test request\n\n## Test request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -11.0078125,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -13.59375,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -1.7089844,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.68847656,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 28771,
|
||||||
|
"logprob": -1.9394531,
|
||||||
|
"special": false,
|
||||||
|
"text": "#"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -2.8828125,
|
||||||
|
"special": false,
|
||||||
|
"text": " Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -0.37329102,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.2602539,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -0.0017185211,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1064,
|
||||||
|
"logprob": -2.2753906,
|
||||||
|
"special": false,
|
||||||
|
"text": "##"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3735,
|
||||||
|
"logprob": -1.9316406,
|
||||||
|
"special": false,
|
||||||
|
"text": " Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2159,
|
||||||
|
"logprob": -0.48217773,
|
||||||
|
"special": false,
|
||||||
|
"text": " request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\n# Test request\n\n## Test request"
|
||||||
|
}
|
||||||
|
]
|
|
@ -0,0 +1,60 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def flash_mixtral_gptq_handle(launcher):
    """Launch the GPTQ-quantized Mixtral model on two shards for this module."""
    model_id = "TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ"
    with launcher(model_id, num_shard=2) as handle:
        yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
async def flash_mixtral_gptq(flash_mixtral_gptq_handle):
    """Wait (with a 300 health-check argument) for the server, then return its client."""
    handle = flash_mixtral_gptq_handle
    await handle.health(300)
    return handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
    """Generate ten tokens with default sampling and compare to the snapshot."""
    client = flash_mixtral_gptq
    response = await client.generate(
        "Test request",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
    """Generate with every sampling parameter set and compare to the snapshot."""
    # Collected in one place so the full set of exercised parameters is easy
    # to see; they are forwarded unchanged as keyword arguments.
    generation_params = dict(
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )
    response = await flash_mixtral_gptq.generate("Test request", **generation_params)

    assert response.details.generated_tokens == 10
    assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_flash_mixtral_gptq_load(
    flash_mixtral_gptq, generate_load, response_snapshot
):
    """Send four identical requests and check that all outputs agree."""
    responses = await generate_load(
        flash_mixtral_gptq, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4

    # Every response must carry the same generated text as the first one.
    texts = [r.generated_text for r in responses]
    assert all(text == texts[0] for text in texts), f"{texts}"

    assert responses == response_snapshot
|
|
@ -1244,12 +1244,12 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "moe-kernels"
|
name = "moe-kernels"
|
||||||
version = "0.3.1"
|
version = "0.4.0"
|
||||||
description = "MoE kernels"
|
description = "MoE kernels"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:b679984a53807127f25af053ec0a2c07dec97ec196f76363a8bfdc3fbb3d1a9a"},
|
{file = "moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:3fc0475bb3b9c09bbf08f6f6e9767d10eaba55b558f67a605fe70ae0cbb5e6a4"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1259,16 +1259,16 @@ triton = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
|
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "moe-kernels"
|
name = "moe-kernels"
|
||||||
version = "0.3.1"
|
version = "0.4.0"
|
||||||
description = "MoE kernels"
|
description = "MoE kernels"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:29684f81495f6e032085295c86d160022f03d5d9a9981446f311ca94fbbbc2cd"},
|
{file = "moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:8ca72a064ceb84a23a3437cc6e6363907ad41588877f6acb1febc010fc7beb22"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1278,16 +1278,16 @@ triton = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
|
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "moe-kernels"
|
name = "moe-kernels"
|
||||||
version = "0.3.1"
|
version = "0.4.0"
|
||||||
description = "MoE kernels"
|
description = "MoE kernels"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9dfdbef48b5b7e97912aaa7420b1b694876a3281f5edfe7d4ca9a69e1f48bff2"},
|
{file = "moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:d302d6b16bb4905b2312dc68da6a6f51e87d0cd3c4bf1f23d995501162399a8e"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1297,16 +1297,16 @@ triton = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
|
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "moe-kernels"
|
name = "moe-kernels"
|
||||||
version = "0.3.1"
|
version = "0.4.0"
|
||||||
description = "MoE kernels"
|
description = "MoE kernels"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:f7d0fc8f191c905a668f3d2eb889999ee988048d08bfd7062d64bca3876588ae"},
|
{file = "moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:6aee3e723efa5113c338b40e6cb20fa62da6c442c65c1a6cc97751d34158a93a"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1316,7 +1316,7 @@ triton = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
|
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mpmath"
|
name = "mpmath"
|
||||||
|
|
|
@ -47,10 +47,10 @@ marlin-kernels = [
|
||||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||||
]
|
]
|
||||||
moe-kernels = [
|
moe-kernels = [
|
||||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
||||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
|
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
|
||||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
||||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||||
]
|
]
|
||||||
rich = "^13.7.1"
|
rich = "^13.7.1"
|
||||||
|
|
||||||
|
|
|
@ -10,13 +10,18 @@ from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||||
|
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
|
||||||
|
from text_generation_server.layers.moe.gptq_marlin import (
|
||||||
|
GPTQMarlinSparseMoELayer,
|
||||||
|
can_use_marlin_moe_gemm,
|
||||||
|
)
|
||||||
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils.log import log_once
|
from text_generation_server.utils.log import log_once
|
||||||
from text_generation_server.utils.weights import (
|
from text_generation_server.utils.weights import (
|
||||||
DefaultWeightsLoader,
|
DefaultWeightsLoader,
|
||||||
UnquantizedWeight,
|
|
||||||
Weights,
|
Weights,
|
||||||
|
UnquantizedWeight,
|
||||||
)
|
)
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
if SYSTEM == "rocm":
|
||||||
|
@ -205,14 +210,18 @@ class SparseMoELayer(nn.Module):
|
||||||
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
||||||
) or isinstance(weights.loader, HybridFP8UnquantLoader):
|
) or isinstance(weights.loader, HybridFP8UnquantLoader):
|
||||||
cls = UnquantizedSparseMoELayer
|
cls = UnquantizedSparseMoELayer
|
||||||
# Once we wire up GPTQ-Marlin MoE:
|
elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
|
||||||
# elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
|
cls = GPTQMarlinSparseMoELayer
|
||||||
# cls = GPTQMarlinSparseMoELayer
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
|
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
log_once(
|
||||||
|
logger.info,
|
||||||
|
"Using MoE layer wih fused gemm",
|
||||||
|
)
|
||||||
|
|
||||||
self.moe = cls(
|
self.moe = cls(
|
||||||
n_expert_group=n_expert_group,
|
n_expert_group=n_expert_group,
|
||||||
n_experts=n_experts,
|
n_experts=n_experts,
|
||||||
|
@ -237,6 +246,15 @@ class SparseMoELayer(nn.Module):
|
||||||
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
||||||
)
|
)
|
||||||
or isinstance(weights.loader, HybridFP8UnquantLoader)
|
or isinstance(weights.loader, HybridFP8UnquantLoader)
|
||||||
# Once we wire up GPTQ-Marlin MoE:
|
or (
|
||||||
# or isinstance(weights.loader, GPTQMarlinWeightsLoader)
|
isinstance(weights.loader, GPTQMarlinWeightsLoader)
|
||||||
|
and can_use_marlin_moe_gemm(
|
||||||
|
desc_act=weights.loader.desc_act,
|
||||||
|
groupsize=weights.loader.groupsize,
|
||||||
|
quant_method=weights.loader.quant_method,
|
||||||
|
quantize=weights.loader.quantize,
|
||||||
|
sym=weights.loader.sym,
|
||||||
|
use_tp=weights.process_group.size() > 1,
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,225 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
from text_generation_server.utils.weights import Weights
|
||||||
|
from text_generation_server.layers.marlin.gptq import (
|
||||||
|
GPTQMarlinWeight,
|
||||||
|
GPTQMarlinWeightsLoader,
|
||||||
|
)
|
||||||
|
|
||||||
|
if SYSTEM == "cuda":
|
||||||
|
from moe_kernels.fused_marlin_moe import fused_marlin_moe
|
||||||
|
else:
|
||||||
|
fused_marlin_moe = None
|
||||||
|
|
||||||
|
|
||||||
|
# Probe the CUDA device's compute capability once at import time.
# `has_sm_8_0` is True only when the major capability is at least 8; any
# failure while querying the device (e.g. no usable CUDA device) is treated
# as "not available" rather than an import error.
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    has_sm_8_0 = False
else:
    has_sm_8_0 = major >= 8
|
||||||
|
|
||||||
|
|
||||||
|
def can_use_marlin_moe_gemm(
    *,
    desc_act: bool,
    groupsize: int,
    quant_method: str,
    quantize: str,
    sym: bool,
    use_tp: bool,
):
    """
    Tell whether the fused GPTQ-Marlin MoE kernel can be used.

    All of the following must hold: the system is CUDA, the fused kernel was
    imported, the device reports compute capability major >= 8, both
    `quantize` and `quant_method` are "gptq", quantization is symmetric
    (`sym`), and `is_full_k(desc_act, groupsize, use_tp)` is true.
    """
    # Platform / kernel availability.
    if SYSTEM != "cuda" or fused_marlin_moe is None or not has_sm_8_0:
        return False

    # Only plain GPTQ checkpoints are handled.
    if quantize != "gptq" or quant_method != "gptq":
        return False

    # Symmetric quantization and a supported desc_act / TP combination.
    return sym and is_full_k(desc_act, groupsize, use_tp)
|
||||||
|
|
||||||
|
|
||||||
|
def is_full_k(desc_act: bool, groupsize: int, use_tp: bool):
    """
    Return True unless `desc_act` is combined with tensor parallelism.

    A `groupsize` of -1 always yields True, regardless of the other flags;
    otherwise the result is False only when both `desc_act` and `use_tp`
    are set.
    """
    return groupsize == -1 or not (desc_act and use_tp)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GPTQMarlinMoEWeight:
    """
    GPTQ-Marlin weight tensors for all experts of one MoE projection.

    Instances are built by `_pack_weight`, which allocates every tensor with
    a leading dimension of size `n_experts` and copies each expert's
    `GPTQMarlinWeight` tensors into its slot.
    """

    # Per-expert tensors stacked along dimension 0 (one slot per expert).
    qweight: torch.Tensor
    qzeros: torch.Tensor
    scales: torch.Tensor
    g_idx: torch.Tensor
    perm: torch.Tensor
    # Taken from the first expert's weight when the container is created.
    is_full_k: bool
|
||||||
|
|
||||||
|
|
||||||
|
class GPTQMarlinSparseMoELayer(nn.Module):
    """
    Sparse MoE layer backed by the fused GPTQ-Marlin MoE kernel.

    The gate/up and down projections of every expert are loaded as
    GPTQ-Marlin weights and packed into one `GPTQMarlinMoEWeight` each;
    `forward` hands them to `fused_marlin_moe`.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        """
        Load and pack the weights of `n_experts` experts found under `prefix`.

        `n_expert_group` and `topk_group` must either both be None or both
        be set.

        Raises:
            ValueError: if `weights.loader` is not a `GPTQMarlinWeightsLoader`
                with symmetric quantization.
        """
        super().__init__()

        loader = weights.loader
        if not isinstance(loader, GPTQMarlinWeightsLoader) or not loader.sym:
            raise ValueError(
                f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported"
            )

        grouped = n_expert_group is not None
        assert grouped == (
            topk_group is not None
        ), "n_expert_group and topk_group must both be None or have some value"

        # Routing configuration, forwarded verbatim to the fused kernel.
        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.bits = loader.bits

        # Gate and up projections are loaded together as one column-parallel
        # multi-weight per expert.
        self.gate_up_proj = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            names=[gate_proj_name, up_proj_name],
            weights=weights,
        )

        # The down projection is loaded as a row-parallel weight per expert.
        self.down_proj = _load_expert_weights_row(
            prefix=prefix,
            n_experts=n_experts,
            name=down_proj_name,
            weights=weights,
        )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """
        Run the fused MoE kernel on `x` using the router's `gating_output`.

        Grouped top-k routing is enabled when `n_expert_group` is set.
        """
        gate_up = self.gate_up_proj
        down = self.down_proj
        return fused_marlin_moe(
            x,
            w1=gate_up.qweight,
            w2=down.qweight,
            g_idx1=gate_up.g_idx,
            g_idx2=down.g_idx,
            perm1=gate_up.perm,
            perm2=down.perm,
            w1_scale=gate_up.scales,
            w2_scale=down.scales,
            is_full_k1=gate_up.is_full_k,
            is_full_k2=down.is_full_k,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            use_grouped_topk=self.n_expert_group is not None,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
            num_bits=self.bits,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    names: List[str],
    weights: Weights,
) -> GPTQMarlinMoEWeight:
    """
    Load the projections `names` of every expert and pack them together.

    For expert `i`, the tensors `{prefix}.{i}.{name}` (one per entry in
    `names`) are fetched with `weights.get_multi_weights_col(..., 0)` and
    written into slot `i` of the returned container.
    """
    packed: Optional[GPTQMarlinMoEWeight] = None
    for expert_idx in range(n_experts):
        expert_prefixes = [f"{prefix}.{expert_idx}.{name}" for name in names]
        expert_weight = weights.get_multi_weights_col(expert_prefixes, 0)
        assert isinstance(expert_weight, GPTQMarlinWeight)
        packed = _pack_weight(
            n_experts=n_experts,
            expert=expert_idx,
            weight=expert_weight,
            moe_weight=packed,
        )

    # Fails when n_experts is 0, since nothing was packed.
    assert packed is not None
    return packed
|
||||||
|
|
||||||
|
|
||||||
|
def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> GPTQMarlinMoEWeight:
    """
    Load the projection `name` of every expert and pack them together.

    For expert `i`, `{prefix}.{i}.{name}` is fetched with
    `weights.get_weights_row` and written into slot `i` of the returned
    container.
    """
    packed: Optional[GPTQMarlinMoEWeight] = None
    for expert_idx in range(n_experts):
        expert_weight = weights.get_weights_row(f"{prefix}.{expert_idx}.{name}")
        assert isinstance(expert_weight, GPTQMarlinWeight)
        packed = _pack_weight(
            n_experts=n_experts,
            expert=expert_idx,
            weight=expert_weight,
            moe_weight=packed,
        )

    # Fails when n_experts is 0, since nothing was packed.
    assert packed is not None
    return packed
|
||||||
|
|
||||||
|
|
||||||
|
def _pack_weight(
    *,
    n_experts: int,
    expert: int,
    moe_weight: Optional[GPTQMarlinMoEWeight],
    weight: GPTQMarlinWeight,
) -> GPTQMarlinMoEWeight:
    """
    Copy one expert's tensors into slot `expert` of the packed MoE weight.

    When `moe_weight` is None, a new container is allocated first: each
    tensor gets shape `(n_experts,) + <expert tensor shape>` with the
    expert tensor's dtype and device, and `is_full_k` is taken from
    `weight`. The (possibly new) container is returned.
    """
    tensor_fields = ("qweight", "qzeros", "scales", "g_idx", "perm")

    if moe_weight is None:

        def _stacked_like(t: torch.Tensor) -> torch.Tensor:
            # Uninitialized storage; every slot is expected to be filled by
            # one call to _pack_weight per expert.
            return torch.empty(
                (n_experts,) + t.shape,
                dtype=t.dtype,
                device=t.device,
            )

        moe_weight = GPTQMarlinMoEWeight(
            is_full_k=weight.is_full_k,
            **{field: _stacked_like(getattr(weight, field)) for field in tensor_fields},
        )

    for field in tensor_fields:
        getattr(moe_weight, field)[expert] = getattr(weight, field)

    return moe_weight
|
Loading…
Reference in New Issue