From 64142489b69d394cf4801d7265d4b2c3443225a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 8 Oct 2024 11:56:41 +0200 Subject: [PATCH] Add support for fused MoE Marlin for AWQ (#2616) * Add support for fused MoE Marlin for AWQ This uses the updated MoE Marlin kernels from vLLM. * Add integration test for AWQ MoE --- flake.lock | 7 +- flake.nix | 2 +- .../test_flash_mixtral_awq.json | 104 +++++ .../test_flash_mixtral_awq_all_params.json | 99 +++++ .../test_flash_mixtral_awq_load.json | 418 ++++++++++++++++++ .../models/test_flash_mixtral_awq.py | 73 +++ server/poetry.lock | 29 +- server/pyproject.toml | 8 +- .../layers/marlin/gptq.py | 2 +- .../layers/moe/__init__.py | 10 +- .../layers/moe/gptq_marlin.py | 37 +- 11 files changed, 749 insertions(+), 40 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq.json create mode 100644 integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_load.json create mode 100644 integration-tests/models/test_flash_mixtral_awq.py diff --git a/flake.lock b/flake.lock index 04d386b3..69bdb736 100644 --- a/flake.lock +++ b/flake.lock @@ -978,15 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1728029332, - "narHash": "sha256-j0RX3a67lvi2PC5w6J5DHTxM+l96J/OV5sAf34IUfUo=", + "lastModified": 1728314485, + "narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "98049f853346ca780b81fee730715c90d33ac2b4", + "rev": "ef9a73a6f950213db60516ff8fe6d97ca89047b8", "type": "github" }, "original": { "owner": "huggingface", + "ref": "moe-kernels-0.6.0", "repo": "text-generation-inference-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index edef442f..ef696f01 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix/moe-kernels-0.6.0"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq.json b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq.json new file mode 100644 index 00000000..9ca22e10 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq.json @@ -0,0 +1,104 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1824, + "logprob": -12.296875, + "text": "What" + }, + { + "id": 349, + "logprob": -0.97216797, + "text": "is" + }, + { + "id": 3534, + "logprob": -10.1796875, + "text": "deep" + }, + { + "id": 5168, + "logprob": -0.9658203, + "text": "learning" + }, + { + "id": 28804, + "logprob": -0.44384766, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.50878906, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.8876953, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.15124512, + "special": false, + "text": "Deep" + }, + { + "id": 5168, + "logprob": -0.030288696, + "special": false, + "text": " learning" + }, + { + "id": 349, + "logprob": -0.16687012, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.17858887, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": -0.8046875, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": -0.007205963, + "special": false, + "text": " of" + }, + { + "id": 5599, + "logprob": -0.090026855, + "special": false, + "text": " machine" + }, + { + "id": 5168, + "logprob": -0.0030670166, + "special": false, + "text": " learning" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a subset of machine learning" +} diff --git a/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_all_params.json b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_all_params.json new file mode 100644 index 00000000..38ab7263 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_all_params.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 349, + "logprob": -13.921875, + "text": "is" + }, + { + "id": 3534, + "logprob": -11.2265625, + "text": "deep" + }, + { + "id": 5168, + "logprob": -2.3886719, + "text": "learning" + }, + { + "id": 28804, + "logprob": -4.7109375, + "text": "?" + } + ], + "seed": 0, + "tokens": [ + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.5229492, + "special": false, + "text": "Deep" + }, + { + "id": 17504, + "logprob": 0.0, + "special": false, + "text": " Learning" + }, + { + "id": 349, + "logprob": -0.5151367, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": 0.0, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": 0.0, + "special": false, + "text": " of" + }, + { + "id": 13253, + "logprob": -1.3359375, + "special": false, + "text": " Machine" + }, + { + "id": 17504, + "logprob": 0.0, + "special": false, + "text": " Learning" + }, + { + "id": 28725, + "logprob": 0.0, + "special": false, + "text": "," + } + ], + "top_tokens": null + }, + "generated_text": "What is deep learning?\nDeep Learning is a subset of Machine Learning," +} diff --git a/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_load.json b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_load.json new file mode 100644 index 00000000..329d73ee --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_load.json @@ -0,0 +1,418 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1824, + "logprob": -12.296875, + "text": "What" + }, + { + "id": 349, + "logprob": -0.97216797, + "text": "is" + }, + { + "id": 3534, + "logprob": -10.1796875, + "text": "deep" + }, + { + "id": 5168, + "logprob": -0.9658203, + "text": "learning" + }, + { + "id": 28804, + "logprob": -0.44384766, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.50878906, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.8876953, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.15136719, + "special": false, + "text": "Deep" + }, + { + "id": 5168, + "logprob": -0.030273438, + "special": false, + "text": " learning" + }, + { + "id": 349, + "logprob": -0.1665039, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.1776123, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": -0.8076172, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": -0.007183075, + "special": false, + "text": " of" + }, + { + "id": 5599, + "logprob": -0.090148926, + "special": false, + "text": " machine" + }, + { + "id": 5168, + "logprob": -0.0030670166, + "special": false, + "text": " learning" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1824, + "logprob": -12.34375, + "text": "What" + }, + { + "id": 349, + "logprob": -0.96728516, + "text": "is" + }, + { + "id": 3534, + "logprob": -10.1796875, + "text": "deep" + }, + { + "id": 5168, + "logprob": -0.97265625, + "text": "learning" + }, + { + "id": 28804, + "logprob": -0.44189453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.51220703, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.87402344, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.15039062, + "special": false, + "text": "Deep" + }, + { + "id": 5168, + "logprob": -0.030288696, + "special": false, + "text": " learning" + }, + { + "id": 349, + "logprob": -0.1652832, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.17858887, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": -0.81103516, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": -0.007183075, + "special": false, + "text": " of" + }, + { + "id": 5599, + "logprob": -0.08880615, + "special": false, + "text": " machine" + }, + { + "id": 5168, + "logprob": -0.0030612946, + "special": false, + "text": " learning" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1824, + "logprob": -12.34375, + "text": "What" + }, + { + "id": 349, + "logprob": -0.96728516, + "text": "is" + }, + { + "id": 3534, + "logprob": -10.1796875, + "text": "deep" + }, + { + "id": 5168, + "logprob": -0.97265625, + "text": "learning" + }, + { + "id": 28804, + "logprob": -0.44189453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.51220703, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.87402344, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.15039062, + "special": false, + "text": "Deep" + }, + { + "id": 5168, + "logprob": -0.030288696, + "special": false, + "text": " learning" + }, + { + "id": 349, + "logprob": -0.1652832, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.17858887, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": -0.81103516, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": -0.007183075, + "special": false, + "text": " of" + }, + { + "id": 5599, + "logprob": -0.08880615, + "special": false, + "text": " machine" + }, + { + "id": 5168, + "logprob": -0.0030612946, + "special": false, + "text": " learning" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a subset of machine learning" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 1824, + "logprob": -12.34375, + "text": "What" + }, + { + "id": 349, + "logprob": -0.96728516, + "text": "is" + }, + { + "id": 3534, + "logprob": -10.1796875, + "text": "deep" + }, + { + "id": 5168, + "logprob": -0.97265625, + "text": "learning" + }, + { + "id": 28804, + "logprob": -0.44189453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 13, + "logprob": -0.51220703, + "special": false, + "text": "\n" + }, + { + "id": 13, + "logprob": -0.87402344, + "special": false, + "text": "\n" + }, + { + "id": 23229, + "logprob": -0.15039062, + "special": false, + "text": "Deep" + }, + { + "id": 5168, + "logprob": -0.030288696, + "special": false, + "text": " learning" + }, + { + "id": 349, + "logprob": -0.1652832, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.17858887, + "special": false, + "text": " a" + }, + { + "id": 19804, + "logprob": -0.81103516, + "special": false, + "text": " subset" + }, + { + "id": 302, + "logprob": -0.007183075, + "special": false, + "text": " of" + }, + { + "id": 5599, + "logprob": -0.08880615, + "special": false, + "text": " machine" + }, + { + "id": 5168, + "logprob": -0.0030612946, + "special": false, + "text": " learning" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a subset of machine learning" + } +] diff --git a/integration-tests/models/test_flash_mixtral_awq.py b/integration-tests/models/test_flash_mixtral_awq.py new file mode 100644 index 00000000..ab1e0f00 --- /dev/null +++ b/integration-tests/models/test_flash_mixtral_awq.py @@ -0,0 +1,73 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_mixtral_awq_handle(launcher): + with launcher("casperhansen/mixtral-instruct-awq", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_mixtral_awq(flash_mixtral_awq_handle): + await flash_mixtral_awq_handle.health(300) + return flash_mixtral_awq_handle.client + + +@pytest.mark.asyncio +async def test_flash_mixtral_awq(flash_mixtral_awq, response_snapshot): + response = await flash_mixtral_awq.generate( + "What is deep learning?", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text == "\n\nDeep learning is a subset of machine learning" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_mixtral_awq_all_params(flash_mixtral_awq, response_snapshot): + response = await flash_mixtral_awq.generate( + "What is deep learning?", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert ( + response.generated_text + == "What is deep learning?\nDeep Learning is a subset of Machine Learning," + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +async def test_flash_mixtral_awq_load( + flash_mixtral_awq, generate_load, response_snapshot +): + responses = await generate_load( + flash_mixtral_awq, "What is deep learning?", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert responses[0].details.generated_tokens == 10 + assert ( + responses[0].generated_text + == "\n\nDeep learning is a subset of machine learning" + ) + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"{[r.generated_text for r in responses]}" + + assert responses == response_snapshot diff --git a/server/poetry.lock b/server/poetry.lock index 64e45765..08f74999 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1269,12 +1269,12 @@ files = [ [[package]] name = "moe-kernels" -version = "0.4.0" +version = "0.6.0" description = "MoE kernels" optional = true python-versions = ">=3.7" files = [ - {file = "moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:3fc0475bb3b9c09bbf08f6f6e9767d10eaba55b558f67a605fe70ae0cbb5e6a4"}, + {file = "moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:f28fd2a56c3ac7bfe74bc44cc7c8c0791a2644ad689b084ea4ed6decb7f41c25"}, ] [package.dependencies] @@ -1284,16 +1284,16 @@ triton = "*" [package.source] type = "url" -url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" [[package]] name = "moe-kernels" -version = "0.4.0" +version = "0.6.0" description = "MoE kernels" optional = true python-versions = ">=3.7" files = [ - {file = "moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:8ca72a064ceb84a23a3437cc6e6363907ad41588877f6acb1febc010fc7beb22"}, + {file = "moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:db475948fd9f7a8647aa3f73256ff4d3bb111425305bcd0b0d3559ccc75b8937"}, ] [package.dependencies] @@ -1303,16 +1303,16 @@ triton = "*" [package.source] type = "url" -url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" [[package]] name = "moe-kernels" -version = "0.4.0" +version = "0.6.0" description = "MoE kernels" optional = true python-versions = ">=3.7" files = [ - {file = "moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:d302d6b16bb4905b2312dc68da6a6f51e87d0cd3c4bf1f23d995501162399a8e"}, + {file = "moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:364be07c06aafbab1f51d9e26d9a4ff658defe1462a4c645abaf7b895ed163a8"}, ] [package.dependencies] @@ -1322,16 +1322,16 @@ triton = "*" [package.source] type = "url" -url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" [[package]] name = "moe-kernels" -version = "0.4.0" +version = "0.6.0" description = "MoE kernels" optional = true python-versions = ">=3.7" files = [ - {file = "moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:6aee3e723efa5113c338b40e6cb20fa62da6c442c65c1a6cc97751d34158a93a"}, + {file = "moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:81e7fa25fb5ed5336f5151994f5e3f600df7e166fe013576968c59415e442894"}, ] [package.dependencies] @@ -1341,7 +1341,7 @@ triton = "*" [package.source] type = "url" -url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" +url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" [[package]] name = "mpmath" @@ -3402,11 +3402,6 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, - {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, - {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, - {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, - {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, - {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] diff --git a/server/pyproject.toml b/server/pyproject.toml index ef67deb1..08f305e6 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -47,10 +47,10 @@ marlin-kernels = [ { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] moe-kernels = [ - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, - { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.4.0/moe_kernels-0.4.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, + { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, + { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, + { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, + { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] rich = "^13.7.1" diff --git a/server/text_generation_server/layers/marlin/gptq.py b/server/text_generation_server/layers/marlin/gptq.py index 0a785d94..7245431f 100644 --- a/server/text_generation_server/layers/marlin/gptq.py +++ b/server/text_generation_server/layers/marlin/gptq.py @@ -43,7 +43,7 @@ def can_use_gptq_marlin( and quant_method in {"awq", "gptq"} and bits in GPTQ_MARLIN_BITS and groupsize in GPTQ_MARLIN_GROUP_SIZES - # We only suppord asymmetric quantization for AWQ. + # We only support asymmetric quantization for AWQ. and (sym or quant_method == "awq") ) diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py index 2c46ca02..558d9ed9 100644 --- a/server/text_generation_server/layers/moe/__init__.py +++ b/server/text_generation_server/layers/moe/__init__.py @@ -210,11 +210,17 @@ class SparseMoELayer(nn.Module): and isinstance(weights.loader.weight_class, UnquantizedWeight) ) or isinstance(weights.loader, HybridFP8UnquantLoader): cls = UnquantizedSparseMoELayer - elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym: + elif isinstance( + weights.loader, GPTQMarlinWeightsLoader + ) and can_use_marlin_moe_gemm( + quant_method=weights.loader.quant_method, + quantize=weights.loader.quantize, + sym=weights.loader.sym, + ): cls = GPTQMarlinSparseMoELayer else: raise ValueError( - f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights" + f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights" ) log_once( diff --git a/server/text_generation_server/layers/moe/gptq_marlin.py b/server/text_generation_server/layers/moe/gptq_marlin.py index 3217cdc2..3d4ca9d8 100644 --- a/server/text_generation_server/layers/moe/gptq_marlin.py +++ b/server/text_generation_server/layers/moe/gptq_marlin.py @@ -34,9 +34,10 @@ def can_use_marlin_moe_gemm( SYSTEM == "cuda" and fused_marlin_moe is not None and has_sm_8_0 - and quantize == "gptq" - and quant_method == "gptq" - and sym + and quantize in {"awq", "gptq"} + and quant_method in {"awq", "gptq"} + # We only support asymmetric quantization for AWQ. + and (sym or quant_method == "awq") ) @@ -72,10 +73,15 @@ class GPTQMarlinSparseMoELayer(nn.Module): super().__init__() if not ( - isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym + isinstance(weights.loader, GPTQMarlinWeightsLoader) + and can_use_marlin_moe_gemm( + quant_method=weights.loader.quant_method, + quantize=weights.loader.quantize, + sym=weights.loader.sym, + ) ): raise ValueError( - f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported" + f"Unsupported weights loader: {type(weights.loader)}, only GPTQMarlinWeightsLoader with AWQ and symmetric GPTQ quantization is supported" ) assert (n_expert_group is None) == ( @@ -102,17 +108,24 @@ class GPTQMarlinSparseMoELayer(nn.Module): def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor: return fused_marlin_moe( - x, + hidden_states=x, w1=self.gate_up_proj.qweight, w2=self.down_proj.qweight, - g_idx1=self.gate_up_proj.g_idx, - g_idx2=self.down_proj.g_idx, - perm1=self.gate_up_proj.perm, - perm2=self.down_proj.perm, w1_scale=self.gate_up_proj.scales, w2_scale=self.down_proj.scales, - is_full_k1=self.gate_up_proj.is_full_k, - is_full_k2=self.down_proj.is_full_k, + w1_zeros=( + self.gate_up_proj.qzeros + if self.gate_up_proj.qzeros.numel() > 0 + else None + ), + w2_zeros=( + self.down_proj.qzeros if self.down_proj.qzeros.numel() > 0 else None + ), + g_idx1=self.gate_up_proj.g_idx, + g_idx2=self.down_proj.g_idx, + sort_indices1=self.gate_up_proj.perm, + sort_indices2=self.down_proj.perm, + is_k_full=self.gate_up_proj.is_full_k or self.down_proj.is_full_k, gating_output=gating_output, topk=self.topk, renormalize=self.renormalize,