Add support for GPTQ-quantized MoE models using MoE Marlin (#2557)

This change adds support for MoE models that use GPTQ quantization.
Currently only models with the following properties are supported:

- No `desc_act` with tensor parallelism, unless `group_size=-1`.
- No asymmetric quantization.
- No AWQ.
This commit is contained in:
Daniël de Kok 2024-09-30 11:14:32 +02:00 committed by GitHub
parent f9e561eced
commit 90a1d04a2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 870 additions and 30 deletions

View File

@ -978,16 +978,17 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1727353315, "lastModified": 1727681169,
"narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=", "narHash": "sha256-ssoGLmRoyQ+8d5utr5fwLox+/eQ789iVtUj1xrukIC0=",
"owner": "huggingface", "owner": "danieldk",
"repo": "text-generation-inference-nix", "repo": "tgi-nix",
"rev": "1d42c4125ebafb87707118168995675cc5050b9d", "rev": "88ba4cfe378d8fb08222f640ff2b62ac0ee6569d",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "huggingface", "owner": "danieldk",
"repo": "text-generation-inference-nix", "ref": "moe-kernels-0.4.0",
"repo": "tgi-nix",
"type": "github" "type": "github"
} }
} }

View File

@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
}; };
nix-filter.url = "github:numtide/nix-filter"; nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix"; tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.4.0";
nixpkgs.follows = "tgi-nix/nixpkgs"; nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
rust-overlay = { rust-overlay = {

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8808594,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37280273,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.26098633,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017137527,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2695312,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9238281,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48828125,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
}

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -0.34838867,
"special": false,
"text": "\n"
},
{
"id": 13940,
"logprob": -0.38916016,
"special": false,
"text": "``"
},
{
"id": 28832,
"logprob": 0.0,
"special": false,
"text": "`"
},
{
"id": 3371,
"logprob": -1.2529297,
"special": false,
"text": "json"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 28751,
"logprob": 0.0,
"special": false,
"text": "{"
},
{
"id": 13,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 2287,
"logprob": 0.0,
"special": false,
"text": " "
},
{
"id": 345,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 3134,
"logprob": -0.640625,
"special": false,
"text": "request"
}
],
"top_tokens": null
},
"generated_text": "Test request\n```json\n{\n \"request"
}

View File

@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"special": false,
"text": "#"
},
{
"id": 3735,
"logprob": -2.8828125,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.37329102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.2602539,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.0017185211,
"special": false,
"text": "\n"
},
{
"id": 1064,
"logprob": -2.2753906,
"special": false,
"text": "##"
},
{
"id": 3735,
"logprob": -1.9316406,
"special": false,
"text": " Test"
},
{
"id": 2159,
"logprob": -0.48217773,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
}
]

View File

@ -0,0 +1,60 @@
import pytest
@pytest.fixture(scope="module")
def flash_mixtral_gptq_handle(launcher):
    """Module-scoped fixture that yields a launcher handle for the GPTQ Mixtral model.

    `launcher` is a fixture defined outside this file. It is used as a context
    manager here, so the yielded handle is only valid while the `with` block
    is open.
    """
    # NOTE(review): num_shard=2 presumably starts two tensor-parallel shards —
    # confirm against the `launcher` fixture.
    with launcher("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", num_shard=2) as handle:
        yield handle
@pytest.fixture(scope="module")
async def flash_mixtral_gptq(flash_mixtral_gptq_handle):
    """Module-scoped fixture that returns the client of the launched handle.

    The handle's health check is awaited before the client is handed out.
    """
    # NOTE(review): 300 is presumably a timeout in seconds — confirm against `health`.
    await flash_mixtral_gptq_handle.health(300)
    return flash_mixtral_gptq_handle.client
@pytest.mark.asyncio
async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
    """Generate 10 new tokens for a fixed prompt, with no sampling parameters
    set, and compare the full response against the stored snapshot."""
    prompt = "Test request"
    generated = await flash_mixtral_gptq.generate(
        prompt,
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert generated == response_snapshot
@pytest.mark.asyncio
async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
    """Generate with every sampling/decoding option set and a fixed seed, then
    check the generated-token count and compare against the stored snapshot."""
    # Keyword order is kept identical to the direct call this replaces.
    generation_options = dict(
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )
    result = await flash_mixtral_gptq.generate("Test request", **generation_options)

    assert result.details.generated_tokens == 10
    assert result == response_snapshot
@pytest.mark.asyncio
async def test_flash_mixtral_gptq_load(
    flash_mixtral_gptq, generate_load, response_snapshot
):
    """Issue the same request four times through `generate_load` and require
    that all responses carry the same generated text and match the snapshot."""
    request_count = 4
    responses = await generate_load(
        flash_mixtral_gptq, "Test request", max_new_tokens=10, n=request_count
    )

    assert len(responses) == request_count

    # Every response is compared against the first one; on failure the
    # assertion message lists all generated texts.
    reference_text = responses[0].generated_text
    assert all(
        r.generated_text == reference_text for r in responses
    ), f"{[r.generated_text for r in responses]}"

    assert responses == response_snapshot

24
server/poetry.lock generated
View File

@ -1244,12 +1244,12 @@ files = [
[[package]] [[package]]
name = "moe-kernels" name = "moe-kernels"
version = "0.3.1" version = "0.4.0"
description = "MoE kernels" description = "MoE kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:b679984a53807127f25af053ec0a2c07dec97ec196f76363a8bfdc3fbb3d1a9a"}, {file = "moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:3fc0475bb3b9c09bbf08f6f6e9767d10eaba55b558f67a605fe70ae0cbb5e6a4"},
] ]
[package.dependencies] [package.dependencies]
@ -1259,16 +1259,16 @@ triton = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl" url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
[[package]] [[package]]
name = "moe-kernels" name = "moe-kernels"
version = "0.3.1" version = "0.4.0"
description = "MoE kernels" description = "MoE kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:29684f81495f6e032085295c86d160022f03d5d9a9981446f311ca94fbbbc2cd"}, {file = "moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:8ca72a064ceb84a23a3437cc6e6363907ad41588877f6acb1febc010fc7beb22"},
] ]
[package.dependencies] [package.dependencies]
@ -1278,16 +1278,16 @@ triton = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl" url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
[[package]] [[package]]
name = "moe-kernels" name = "moe-kernels"
version = "0.3.1" version = "0.4.0"
description = "MoE kernels" description = "MoE kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9dfdbef48b5b7e97912aaa7420b1b694876a3281f5edfe7d4ca9a69e1f48bff2"}, {file = "moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:d302d6b16bb4905b2312dc68da6a6f51e87d0cd3c4bf1f23d995501162399a8e"},
] ]
[package.dependencies] [package.dependencies]
@ -1297,16 +1297,16 @@ triton = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl" url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
[[package]] [[package]]
name = "moe-kernels" name = "moe-kernels"
version = "0.3.1" version = "0.4.0"
description = "MoE kernels" description = "MoE kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:f7d0fc8f191c905a668f3d2eb889999ee988048d08bfd7062d64bca3876588ae"}, {file = "moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:6aee3e723efa5113c338b40e6cb20fa62da6c442c65c1a6cc97751d34158a93a"},
] ]
[package.dependencies] [package.dependencies]
@ -1316,7 +1316,7 @@ triton = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl" url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
[[package]] [[package]]
name = "mpmath" name = "mpmath"

View File

@ -47,10 +47,10 @@ marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
] ]
moe-kernels = [ moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
] ]
rich = "^13.7.1" rich = "^13.7.1"

View File

@ -10,13 +10,18 @@ from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
from text_generation_server.layers.moe.gptq_marlin import (
GPTQMarlinSparseMoELayer,
can_use_marlin_moe_gemm,
)
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import ( from text_generation_server.utils.weights import (
DefaultWeightsLoader, DefaultWeightsLoader,
UnquantizedWeight,
Weights, Weights,
UnquantizedWeight,
) )
if SYSTEM == "rocm": if SYSTEM == "rocm":
@ -205,14 +210,18 @@ class SparseMoELayer(nn.Module):
and isinstance(weights.loader.weight_class, UnquantizedWeight) and isinstance(weights.loader.weight_class, UnquantizedWeight)
) or isinstance(weights.loader, HybridFP8UnquantLoader): ) or isinstance(weights.loader, HybridFP8UnquantLoader):
cls = UnquantizedSparseMoELayer cls = UnquantizedSparseMoELayer
# Once we wire up GPTQ-Marlin MoE: elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
# elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym: cls = GPTQMarlinSparseMoELayer
# cls = GPTQMarlinSparseMoELayer
else: else:
raise ValueError( raise ValueError(
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights" f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
) )
log_once(
logger.info,
"Using MoE layer wih fused gemm",
)
self.moe = cls( self.moe = cls(
n_expert_group=n_expert_group, n_expert_group=n_expert_group,
n_experts=n_experts, n_experts=n_experts,
@ -237,6 +246,15 @@ class SparseMoELayer(nn.Module):
and isinstance(weights.loader.weight_class, UnquantizedWeight) and isinstance(weights.loader.weight_class, UnquantizedWeight)
) )
or isinstance(weights.loader, HybridFP8UnquantLoader) or isinstance(weights.loader, HybridFP8UnquantLoader)
# Once we wire up GPTQ-Marlin MoE: or (
# or isinstance(weights.loader, GPTQMarlinWeightsLoader) isinstance(weights.loader, GPTQMarlinWeightsLoader)
and can_use_marlin_moe_gemm(
desc_act=weights.loader.desc_act,
groupsize=weights.loader.groupsize,
quant_method=weights.loader.quant_method,
quantize=weights.loader.quantize,
sym=weights.loader.sym,
use_tp=weights.process_group.size() > 1,
)
)
) )

View File

@ -0,0 +1,225 @@
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeight,
GPTQMarlinWeightsLoader,
)
# The fused kernel is imported only on CUDA systems. Elsewhere the name is
# bound to None, which `can_use_marlin_moe_gemm` checks for.
if SYSTEM == "cuda":
    from moe_kernels.fused_marlin_moe import fused_marlin_moe
else:
    fused_marlin_moe = None

# Probe the CUDA compute capability once at import time.
# `can_use_marlin_moe_gemm` only reports support when the major version is >= 8.
try:
    major, _minor = torch.cuda.get_device_capability()
    has_sm_8_0 = major >= 8
except Exception:
    # Any failure of the probe is treated as "no SM 8.0 device".
    # NOTE(review): presumably this covers hosts without a usable CUDA device — confirm.
    has_sm_8_0 = False
def can_use_marlin_moe_gemm(
    *,
    desc_act: bool,
    groupsize: int,
    quant_method: str,
    quantize: str,
    sym: bool,
    use_tp: bool,
):
    """
    Tell whether the fused GPTQ-Marlin MoE kernel can be used.

    All of the following must hold:

    - the system is CUDA, the fused kernel was imported, and the device
      reported a compute capability major version of at least 8;
    - both `quantize` and `quant_method` are `"gptq"`;
    - `sym` is truthy;
    - `is_full_k(desc_act, groupsize, use_tp)` is truthy.
    """
    kernel_available = (
        SYSTEM == "cuda" and fused_marlin_moe is not None and has_sm_8_0
    )
    if not kernel_available:
        return False

    if quantize != "gptq" or quant_method != "gptq":
        return False

    # Kept as an `and` expression so that the returned value is exactly what
    # a single chained `and` over all conditions would yield.
    return sym and is_full_k(desc_act, groupsize, use_tp)
def is_full_k(desc_act: bool, groupsize: int, use_tp: bool):
    """
    Compute the `is_full_k` flag.

    Returns True when `groupsize` is -1. Otherwise returns True unless both
    `desc_act` and `use_tp` are set.
    """
    desc_act_with_tp = desc_act and use_tp
    return groupsize == -1 or not desc_act_with_tp
@dataclass
class GPTQMarlinMoEWeight:
    """
    Container for the per-expert GPTQ-Marlin tensors of one MoE projection.

    Instances built by `_pack_weight` hold tensors whose leading dimension is
    the number of experts, followed by the shape of a single expert's tensor.
    """

    # Passed to `fused_marlin_moe` as `w1` / `w2`.
    qweight: torch.Tensor
    # Stored alongside the other tensors; not passed to the kernel in
    # `GPTQMarlinSparseMoELayer.forward`.
    qzeros: torch.Tensor
    # Passed to `fused_marlin_moe` as `w1_scale` / `w2_scale`.
    scales: torch.Tensor
    # Passed to `fused_marlin_moe` as `g_idx1` / `g_idx2`.
    g_idx: torch.Tensor
    # Passed to `fused_marlin_moe` as `perm1` / `perm2`.
    perm: torch.Tensor
    # Copied from the first expert weight that is packed; passed to the kernel
    # as `is_full_k1` / `is_full_k2`.
    is_full_k: bool
class GPTQMarlinSparseMoELayer(nn.Module):
    """
    Sparse MoE layer whose expert projections run through the fused
    GPTQ-Marlin MoE kernel (`fused_marlin_moe`).

    Only a `GPTQMarlinWeightsLoader` with symmetric quantization is accepted;
    any other loader makes the constructor raise `ValueError`.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        """
        Load and pack the expert weights found under `prefix`.

        The gate and up projections of every expert are loaded together as one
        column-wise multi-weight; the down projection is loaded row-wise.

        Raises:
            ValueError: if `weights.loader` is not a `GPTQMarlinWeightsLoader`
                with `sym` set.
        """
        super().__init__()

        loader_is_supported = (
            isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym
        )
        if not loader_is_supported:
            raise ValueError(
                f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported"
            )

        # Grouped top-k needs both settings; plain top-k needs neither.
        grouping_is_consistent = (n_expert_group is None) == (topk_group is None)
        assert (
            grouping_is_consistent
        ), "n_expert_group and topk_group must both be None or have some value"

        # Routing configuration, forwarded unchanged to the kernel.
        self.renormalize = renormalize
        self.topk = topk
        self.topk_group = topk_group
        self.n_expert_group = n_expert_group

        self.gate_up_proj = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_experts,
            names=[gate_proj_name, up_proj_name],
            weights=weights,
        )
        self.down_proj = _load_expert_weights_row(
            prefix=prefix,
            n_experts=n_experts,
            name=down_proj_name,
            weights=weights,
        )

        self.bits = weights.loader.bits

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """Run the fused kernel on `x` with the given gating output and return its result."""
        gate_up = self.gate_up_proj
        down = self.down_proj
        # Grouped top-k is requested exactly when an expert group count was configured.
        grouped_topk_enabled = self.n_expert_group is not None

        return fused_marlin_moe(
            x,
            w1=gate_up.qweight,
            w2=down.qweight,
            g_idx1=gate_up.g_idx,
            g_idx2=down.g_idx,
            perm1=gate_up.perm,
            perm2=down.perm,
            w1_scale=gate_up.scales,
            w2_scale=down.scales,
            is_full_k1=gate_up.is_full_k,
            is_full_k2=down.is_full_k,
            gating_output=gating_output,
            topk=self.topk,
            renormalize=self.renormalize,
            use_grouped_topk=grouped_topk_enabled,
            num_expert_group=self.n_expert_group,
            topk_group=self.topk_group,
            num_bits=self.bits,
        )
def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    names: List[str],
    weights: Weights,
) -> GPTQMarlinMoEWeight:
    """
    Load the projections listed in `names` for every expert and pack them
    into a single `GPTQMarlinMoEWeight`.

    For expert `i`, the weights are looked up under `{prefix}.{i}.{name}` for
    each name and fetched with one `get_multi_weights_col` call. Every loaded
    weight must be a `GPTQMarlinWeight`, and at least one expert must be
    loaded; both are enforced with assertions.
    """
    packed: Optional[GPTQMarlinMoEWeight] = None

    for expert_idx in range(n_experts):
        expert_prefixes = [f"{prefix}.{expert_idx}.{name}" for name in names]
        # NOTE(review): the positional 0 is presumably the concatenation dim — confirm
        # against `Weights.get_multi_weights_col`.
        expert_weight = weights.get_multi_weights_col(expert_prefixes, 0)
        assert isinstance(expert_weight, GPTQMarlinWeight)

        packed = _pack_weight(
            n_experts=n_experts,
            expert=expert_idx,
            weight=expert_weight,
            moe_weight=packed,
        )

    assert packed is not None
    return packed
def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
) -> GPTQMarlinMoEWeight:
    """
    Load the projection `name` for every expert with `get_weights_row` and
    pack the results into a single `GPTQMarlinMoEWeight`.

    For expert `i`, the weight is looked up under `{prefix}.{i}.{name}`. Every
    loaded weight must be a `GPTQMarlinWeight`, and at least one expert must
    be loaded; both are enforced with assertions.
    """
    packed: Optional[GPTQMarlinMoEWeight] = None

    for expert_idx in range(n_experts):
        expert_prefix = f"{prefix}.{expert_idx}.{name}"
        expert_weight = weights.get_weights_row(expert_prefix)
        assert isinstance(expert_weight, GPTQMarlinWeight)

        packed = _pack_weight(
            n_experts=n_experts,
            expert=expert_idx,
            weight=expert_weight,
            moe_weight=packed,
        )

    assert packed is not None
    return packed
def _pack_weight(
    *,
    n_experts: int,
    expert: int,
    moe_weight: Optional[GPTQMarlinMoEWeight],
    weight: GPTQMarlinWeight,
) -> GPTQMarlinMoEWeight:
    """
    Copy one expert's tensors into slot `expert` of the packed MoE weight.

    When `moe_weight` is None, the packed container is created first: each
    tensor is allocated uninitialized with shape `(n_experts,) + <expert
    tensor shape>` and the dtype and device of the corresponding tensor in
    `weight`, and `is_full_k` is taken from `weight`. The (possibly new)
    container is returned.
    """
    tensor_fields = ("qweight", "qzeros", "scales", "g_idx", "perm")

    if moe_weight is None:

        def _allocate_stacked(template: torch.Tensor) -> torch.Tensor:
            # One uninitialized slot per expert, matching the template tensor.
            return torch.empty(
                (n_experts,) + template.shape,
                dtype=template.dtype,
                device=template.device,
            )

        moe_weight = GPTQMarlinMoEWeight(
            **{
                field: _allocate_stacked(getattr(weight, field))
                for field in tensor_fields
            },
            is_full_k=weight.is_full_k,
        )

    for field in tensor_fields:
        getattr(moe_weight, field)[expert] = getattr(weight, field)

    return moe_weight