diff --git a/flake.lock b/flake.lock index 14e23b77..934ec3d0 100644 --- a/flake.lock +++ b/flake.lock @@ -978,16 +978,17 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1727353315, - "narHash": "sha256-yZovq/6P8Z199r7e+NbTXyCqRgK6grRkLxYHWHnHckI=", - "owner": "huggingface", - "repo": "text-generation-inference-nix", - "rev": "1d42c4125ebafb87707118168995675cc5050b9d", + "lastModified": 1727681169, + "narHash": "sha256-ssoGLmRoyQ+8d5utr5fwLox+/eQ789iVtUj1xrukIC0=", + "owner": "danieldk", + "repo": "tgi-nix", + "rev": "88ba4cfe378d8fb08222f640ff2b62ac0ee6569d", "type": "github" }, "original": { - "owner": "huggingface", - "repo": "text-generation-inference-nix", + "owner": "danieldk", + "ref": "moe-kernels-0.4.0", + "repo": "tgi-nix", "type": "github" } } diff --git a/flake.nix b/flake.nix index 1b396453..ed2dedc9 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix"; + tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.4.0"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq.json b/integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq.json new file mode 100644 index 00000000..993bdadd --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -11.0078125, + "text": "Test" + }, + { + "id": 2159, + "logprob": -13.59375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.7089844, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.68847656, + "special": false, + "text": "\n" + }, + { + "id": 28771, + "logprob": -1.9394531, + "special": false, + "text": "#" + }, + { + "id": 3735, + "logprob": -2.8808594, + "special": false, + "text": " Test" + }, + { + "id": 2159, + "logprob": -0.37280273, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.26098633, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.0017137527, + "special": false, + "text": "\n" + }, + { + "id": 1064, + "logprob": -2.2695312, + "special": false, + "text": "##" + }, + { + "id": 3735, + "logprob": -1.9238281, + "special": false, + "text": " Test" + }, + { + "id": 2159, + "logprob": -0.48828125, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# Test request\n\n## Test request" +} diff --git a/integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq_all_params.json b/integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq_all_params.json new file mode 100644 index 00000000..94411eef --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -11.0078125, + "text": "Test" + }, + { + "id": 2159, + "logprob": -13.59375, + 
"text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": -0.34838867, + "special": false, + "text": "\n" + }, + { + "id": 13940, + "logprob": -0.38916016, + "special": false, + "text": "``" + }, + { + "id": 28832, + "logprob": 0.0, + "special": false, + "text": "`" + }, + { + "id": 3371, + "logprob": -1.2529297, + "special": false, + "text": "json" + }, + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 28751, + "logprob": 0.0, + "special": false, + "text": "{" + }, + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 2287, + "logprob": 0.0, + "special": false, + "text": " " + }, + { + "id": 345, + "logprob": 0.0, + "special": false, + "text": " \"" + }, + { + "id": 3134, + "logprob": -0.640625, + "special": false, + "text": "request" + } + ], + "top_tokens": null + }, + "generated_text": "Test request\n```json\n{\n \"request" +} diff --git a/integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq_load.json b/integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq_load.json new file mode 100644 index 00000000..19e306a3 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -11.0078125, + "text": "Test" + }, + { + "id": 2159, + "logprob": -13.59375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.7089844, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.68847656, + "special": false, + "text": "\n" + }, + { + "id": 28771, + "logprob": -1.9394531, + "special": false, + "text": "#" + }, + { + "id": 3735, + "logprob": -2.8828125, + "special": false, + "text": " Test" + }, + { + "id": 2159, + "logprob": -0.37329102, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.2602539, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.0017185211, + "special": false, + "text": "\n" + }, + { + "id": 1064, + "logprob": -2.2753906, + "special": false, + "text": "##" + }, + { + "id": 3735, + "logprob": -1.9316406, + "special": false, + "text": " Test" + }, + { + "id": 2159, + "logprob": -0.48217773, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# Test request\n\n## Test request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -11.0078125, + "text": "Test" + }, + { + "id": 2159, + "logprob": -13.59375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.7089844, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.68847656, + "special": false, + "text": "\n" + }, + { + "id": 28771, + "logprob": -1.9394531, + "special": false, + "text": "#" + }, + { + "id": 3735, + "logprob": -2.8828125, + "special": false, + "text": " Test" + }, + { + "id": 2159, + "logprob": -0.37329102, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.2602539, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.0017185211, + "special": false, + "text": "\n" + }, + { + "id": 1064, + 
"logprob": -2.2753906, + "special": false, + "text": "##" + }, + { + "id": 3735, + "logprob": -1.9316406, + "special": false, + "text": " Test" + }, + { + "id": 2159, + "logprob": -0.48217773, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# Test request\n\n## Test request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -11.0078125, + "text": "Test" + }, + { + "id": 2159, + "logprob": -13.59375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.7089844, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.68847656, + "special": false, + "text": "\n" + }, + { + "id": 28771, + "logprob": -1.9394531, + "special": false, + "text": "#" + }, + { + "id": 3735, + "logprob": -2.8828125, + "special": false, + "text": " Test" + }, + { + "id": 2159, + "logprob": -0.37329102, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.2602539, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.0017185211, + "special": false, + "text": "\n" + }, + { + "id": 1064, + "logprob": -2.2753906, + "special": false, + "text": "##" + }, + { + "id": 3735, + "logprob": -1.9316406, + "special": false, + "text": " Test" + }, + { + "id": 2159, + "logprob": -0.48217773, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# Test request\n\n## Test request" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 3735, + "logprob": -11.0078125, + "text": "Test" + }, + { + "id": 2159, + "logprob": -13.59375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -1.7089844, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.68847656, + "special": false, + "text": "\n" + }, + { + "id": 28771, + "logprob": -1.9394531, + "special": false, + "text": "#" + }, + { + "id": 3735, + "logprob": -2.8828125, + "special": false, + "text": " Test" + }, + { + "id": 2159, + "logprob": -0.37329102, + "special": false, + "text": " request" + }, + { + "id": 13, + "logprob": -0.2602539, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.0017185211, + "special": false, + "text": "\n" + }, + { + "id": 1064, + "logprob": -2.2753906, + "special": false, + "text": "##" + }, + { + "id": 3735, + "logprob": -1.9316406, + "special": false, + "text": " Test" + }, + { + "id": 2159, + "logprob": -0.48217773, + "special": false, + "text": " request" + } + ], + "top_tokens": null + }, + "generated_text": "\n\n# Test request\n\n## Test request" + } +] diff --git a/integration-tests/models/test_flash_mixtral_gptq.py b/integration-tests/models/test_flash_mixtral_gptq.py new file mode 100644 index 00000000..eb880628 --- /dev/null +++ b/integration-tests/models/test_flash_mixtral_gptq.py @@ -0,0 +1,60 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_mixtral_gptq_handle(launcher): + with launcher("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_mixtral_gptq(flash_mixtral_gptq_handle): + await flash_mixtral_gptq_handle.health(300) + return flash_mixtral_gptq_handle.client + + +@pytest.mark.asyncio +async 
def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot): + response = await flash_mixtral_gptq.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot): + response = await flash_mixtral_gptq.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_mixtral_gptq_load( + flash_mixtral_gptq, generate_load, response_snapshot +): + responses = await generate_load( + flash_mixtral_gptq, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" + + assert responses == response_snapshot diff --git a/server/poetry.lock b/server/poetry.lock index 8d0e31f8..e1b0b3d5 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1244,12 +1244,12 @@ files = [ [[package]] name = "moe-kernels" -version = "0.3.1" +version = "0.4.0" description = "MoE kernels" optional = true python-versions = ">=3.7" files = [ - {file = "moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:b679984a53807127f25af053ec0a2c07dec97ec196f76363a8bfdc3fbb3d1a9a"}, + {file = "moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:3fc0475bb3b9c09bbf08f6f6e9767d10eaba55b558f67a605fe70ae0cbb5e6a4"}, ] [package.dependencies] @@ -1259,16 +1259,16 @@ triton = "*" [package.source] type = "url" -url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" [[package]] name = "moe-kernels" -version = "0.3.1" +version = "0.4.0" description = "MoE kernels" optional = true python-versions = ">=3.7" files = [ - {file = "moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:29684f81495f6e032085295c86d160022f03d5d9a9981446f311ca94fbbbc2cd"}, + {file = "moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:8ca72a064ceb84a23a3437cc6e6363907ad41588877f6acb1febc010fc7beb22"}, ] [package.dependencies] @@ -1278,16 +1278,16 @@ triton = "*" [package.source] type = "url" -url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" [[package]] name = "moe-kernels" -version = "0.3.1" +version = "0.4.0" description = "MoE kernels" optional = true python-versions = ">=3.7" files = [ - {file = "moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9dfdbef48b5b7e97912aaa7420b1b694876a3281f5edfe7d4ca9a69e1f48bff2"}, + {file = "moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:d302d6b16bb4905b2312dc68da6a6f51e87d0cd3c4bf1f23d995501162399a8e"}, ] [package.dependencies] @@ -1297,16 +1297,16 @@ triton = "*" [package.source] type = "url" -url = 
"https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" [[package]] name = "moe-kernels" -version = "0.3.1" +version = "0.4.0" description = "MoE kernels" optional = true python-versions = ">=3.7" files = [ - {file = "moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:f7d0fc8f191c905a668f3d2eb889999ee988048d08bfd7062d64bca3876588ae"}, + {file = "moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:6aee3e723efa5113c338b40e6cb20fa62da6c442c65c1a6cc97751d34158a93a"}, ] [package.dependencies] @@ -1316,7 +1316,7 @@ triton = "*" [package.source] type = "url" -url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" [[package]] name = "mpmath" diff --git a/server/pyproject.toml b/server/pyproject.toml index 6bdd2385..9be06fa5 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -47,10 +47,10 @@ marlin-kernels = [ { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] moe-kernels = [ - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, + { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, + { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, + { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, + { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] rich = "^13.7.1" diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py index 7e8ac2c8..ca71ebab 100644 --- a/server/text_generation_server/layers/moe/__init__.py +++ b/server/text_generation_server/layers/moe/__init__.py @@ -10,13 +10,18 @@ from text_generation_server.layers import ( TensorParallelRowLinear, ) from text_generation_server.layers.fp8 import HybridFP8UnquantLoader +from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader +from text_generation_server.layers.moe.gptq_marlin import ( + GPTQMarlinSparseMoELayer, + can_use_marlin_moe_gemm, +) from 
text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
-    UnquantizedWeight,
     Weights,
+    UnquantizedWeight,
 )
 
 if SYSTEM == "rocm":
@@ -205,14 +210,18 @@ class SparseMoELayer(nn.Module):
             and isinstance(weights.loader.weight_class, UnquantizedWeight)
         ) or isinstance(weights.loader, HybridFP8UnquantLoader):
             cls = UnquantizedSparseMoELayer
-        # Once we wire up GPTQ-Marlin MoE:
-        # elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
-        #     cls = GPTQMarlinSparseMoELayer
+        elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
+            cls = GPTQMarlinSparseMoELayer
         else:
             raise ValueError(
                 f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
             )
 
+        log_once(
+            logger.info,
+            "Using MoE layer with fused GEMM",
+        )
+
         self.moe = cls(
             n_expert_group=n_expert_group,
             n_experts=n_experts,
@@ -237,6 +246,15 @@ class SparseMoELayer(nn.Module):
                 and isinstance(weights.loader.weight_class, UnquantizedWeight)
             )
             or isinstance(weights.loader, HybridFP8UnquantLoader)
-            # Once we wire up GPTQ-Marlin MoE:
-            # or isinstance(weights.loader, GPTQMarlinWeightsLoader)
+            or (
+                isinstance(weights.loader, GPTQMarlinWeightsLoader)
+                and can_use_marlin_moe_gemm(
+                    desc_act=weights.loader.desc_act,
+                    groupsize=weights.loader.groupsize,
+                    quant_method=weights.loader.quant_method,
+                    quantize=weights.loader.quantize,
+                    sym=weights.loader.sym,
+                    use_tp=weights.process_group.size() > 1,
+                )
+            )
         )
diff --git a/server/text_generation_server/layers/moe/gptq_marlin.py b/server/text_generation_server/layers/moe/gptq_marlin.py
new file mode 100644
index 00000000..3fc06cb2
--- /dev/null
+++ b/server/text_generation_server/layers/moe/gptq_marlin.py
@@ -0,0 +1,225 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.weights import Weights
+from text_generation_server.layers.marlin.gptq import (
+    GPTQMarlinWeight,
+    GPTQMarlinWeightsLoader,
+)
+
+if SYSTEM == "cuda":
+    from moe_kernels.fused_marlin_moe import fused_marlin_moe
+else:
+    fused_marlin_moe = None
+
+
+try:
+    major, _minor = torch.cuda.get_device_capability()
+    has_sm_8_0 = major >= 8
+except Exception:
+    has_sm_8_0 = False
+
+
+def can_use_marlin_moe_gemm(
+    *,
+    desc_act: bool,
+    groupsize: int,
+    quant_method: str,
+    quantize: str,
+    sym: bool,
+    use_tp: bool,
+):
+    return (
+        SYSTEM == "cuda"
+        and fused_marlin_moe is not None
+        and has_sm_8_0
+        and quantize == "gptq"
+        and quant_method == "gptq"
+        and sym
+        and is_full_k(desc_act, groupsize, use_tp)
+    )
+
+
+def is_full_k(desc_act: bool, groupsize: int, use_tp: bool):
+    if groupsize == -1:
+        return True
+    return not (desc_act and use_tp)
+
+
+@dataclass
+class GPTQMarlinMoEWeight:
+    qweight: torch.Tensor
+    qzeros: torch.Tensor
+    scales: torch.Tensor
+    g_idx: torch.Tensor
+    perm: torch.Tensor
+    is_full_k: bool
+
+
+class GPTQMarlinSparseMoELayer(nn.Module):
+    """
+    MoE layer that uses a fused GPTQ-Marlin kernel.
+    """
+
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+    ):
+        super().__init__()
+
+        if not (
+            isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym
+        ):
+            raise ValueError(
+                f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported"
+            )
+
+        assert (n_expert_group is None) == (
+            topk_group is None
+        ), "n_expert_group and topk_group must both be None or have some value"
+
+        self.n_expert_group = n_expert_group
+        self.topk = topk
+        self.topk_group = topk_group
+        self.renormalize = renormalize
+
+        self.gate_up_proj = _load_expert_multi_weights_col(
+            prefix=prefix,
+            n_experts=n_experts,
+            names=[gate_proj_name, up_proj_name],
+            weights=weights,
+        )
+
+        self.down_proj = _load_expert_weights_row(
+            prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights
+        )
+
+        self.bits = weights.loader.bits
+
+    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+        return fused_marlin_moe(
+            x,
+            w1=self.gate_up_proj.qweight,
+            w2=self.down_proj.qweight,
+            g_idx1=self.gate_up_proj.g_idx,
+            g_idx2=self.down_proj.g_idx,
+            perm1=self.gate_up_proj.perm,
+            perm2=self.down_proj.perm,
+            w1_scale=self.gate_up_proj.scales,
+            w2_scale=self.down_proj.scales,
+            is_full_k1=self.gate_up_proj.is_full_k,
+            is_full_k2=self.down_proj.is_full_k,
+            gating_output=gating_output,
+            topk=self.topk,
+            renormalize=self.renormalize,
+            use_grouped_topk=self.n_expert_group is not None,
+            num_expert_group=self.n_expert_group,
+            topk_group=self.topk_group,
+            num_bits=self.bits,
+        )
+
+
+def _load_expert_multi_weights_col(
+    *,
+    prefix: str,
+    n_experts: int,
+    names: List[str],
+    weights: Weights,
+) -> GPTQMarlinMoEWeight:
+    moe_weight = None
+    for i in range(n_experts):
+        weight = weights.get_multi_weights_col(
+            [f"{prefix}.{i}.{name}" for name in names], 0
+        )
+        assert isinstance(weight, GPTQMarlinWeight)
+        moe_weight = _pack_weight(
+            n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
+        )
+    assert moe_weight is not None
+    return moe_weight
+
+
+def _load_expert_weights_row(
+    *,
+    prefix: str,
+    n_experts: int,
+    name: str,
+    weights: Weights,
+) -> GPTQMarlinMoEWeight:
+    moe_weight = None
+    for i in range(n_experts):
+        weight = weights.get_weights_row(
+            f"{prefix}.{i}.{name}",
+        )
+        assert isinstance(weight, GPTQMarlinWeight)
+        moe_weight = _pack_weight(
+            n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
+        )
+    assert moe_weight is not None
+    return moe_weight
+
+
+def _pack_weight(
+    *,
+    n_experts: int,
+    expert: int,
+    moe_weight: Optional[GPTQMarlinMoEWeight],
+    weight: GPTQMarlinWeight,
+) -> GPTQMarlinMoEWeight:
+    if moe_weight is None:
+        qweight = torch.empty(
+            (n_experts,) + weight.qweight.shape,
+            dtype=weight.qweight.dtype,
+            device=weight.qweight.device,
+        )
+        qzeros = torch.empty(
+            (n_experts,) + weight.qzeros.shape,
+            dtype=weight.qzeros.dtype,
+            device=weight.qzeros.device,
+        )
+        scales = torch.empty(
+            (n_experts,) + weight.scales.shape,
+            dtype=weight.scales.dtype,
+            device=weight.scales.device,
+        )
+        g_idx = torch.empty(
+            (n_experts,) + weight.g_idx.shape,
+            dtype=weight.g_idx.dtype,
+            device=weight.g_idx.device,
+        )
+        perm = torch.empty(
+            (n_experts,) + weight.perm.shape,
+            dtype=weight.perm.dtype,
+            device=weight.perm.device,
+        )
+
+        moe_weight = GPTQMarlinMoEWeight(
+            qweight=qweight,
+            qzeros=qzeros,
+            scales=scales,
+            g_idx=g_idx,
+            perm=perm,
+            is_full_k=weight.is_full_k,
+        )
+
+    moe_weight.qweight[expert] = weight.qweight
+    moe_weight.qzeros[expert] = weight.qzeros
+    moe_weight.scales[expert] = weight.scales
+    moe_weight.g_idx[expert] = weight.g_idx
+    moe_weight.perm[expert] = weight.perm
+
+    return moe_weight