diff --git a/flake.lock b/flake.lock index aacdd30e..76b4ca2f 100644 --- a/flake.lock +++ b/flake.lock @@ -978,15 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1728381423, - "narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=", + "lastModified": 1729531056, + "narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e", + "rev": "a84a90281a17b15762873845c947e5c78f5a8dd1", "type": "github" }, "original": { "owner": "huggingface", + "ref": "marlin-kernels-0.3.0", "repo": "text-generation-inference-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index f26a983e..5c05bfae 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json index c55dd593..b82882c0 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json @@ -11,27 +11,27 @@ }, { "id": 3923, - "logprob": -5.6328125, + "logprob": -6.1875, "text": "What" }, { "id": 374, - "logprob": -1.2265625, + "logprob": -0.93359375, "text": " is" }, { "id": 5655, - "logprob": -9.1015625, + "logprob": -9.875, "text": " deep" }, { "id": 6975, - "logprob": -1.8085938, + "logprob": -1.1796875, "text": " learning" }, { "id": 30, - "logprob": -1.0439453, + "logprob": -1.75, "text": "?" 
} ], @@ -39,66 +39,66 @@ "tokens": [ { "id": 18682, - "logprob": -2.1992188, + "logprob": -1.109375, "special": false, "text": " Deep" }, { "id": 6975, - "logprob": -0.079956055, + "logprob": -0.005432129, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.2763672, + "logprob": -0.028808594, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.37548828, + "logprob": -0.013671875, "special": false, "text": " a" }, { "id": 27084, - "logprob": -1.4628906, + "logprob": -0.69921875, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.02885437, + "logprob": -0.0005874634, "special": false, "text": " of" }, { "id": 5780, - "logprob": -0.2565918, + "logprob": -0.026855469, "special": false, "text": " machine" }, { "id": 6975, - "logprob": -0.0063438416, + "logprob": -0.00020885468, "special": false, "text": " learning" }, { "id": 430, - "logprob": -1.3056641, + "logprob": -0.17773438, "special": false, "text": " that" }, { - "id": 374, - "logprob": -1.6035156, + "id": 18065, + "logprob": -0.703125, "special": false, - "text": " is" + "text": " involves" } ], "top_tokens": null }, - "generated_text": " Deep learning is a subset of machine learning that is" + "generated_text": " Deep learning is a subset of machine learning that involves" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json index d06d6e56..8bce3e10 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "eos_token", - "generated_tokens": 3, + "finish_reason": "length", + "generated_tokens": 10, "prefill": [ { "id": 128000, @@ -11,22 +11,22 @@ }, { "id": 374, - "logprob": -22.96875, + "logprob": -18.0, "text": " is" }, { "id": 5655, - "logprob": -10.71875, + "logprob": -11.75, "text": " deep" }, { "id": 6975, - "logprob": -2.6992188, + "logprob": -2.0625, "text": " learning" }, { "id": 30, - "logprob": -4.8398438, + "logprob": -6.0, "text": "?" } ], @@ -34,24 +34,66 @@ "tokens": [ { "id": 720, - "logprob": -0.4411621, + "logprob": 0.0, "special": false, "text": " \n" }, { - "id": 220, - "logprob": -0.35864258, + "id": 34564, + "logprob": -0.11279297, "special": false, - "text": " " + "text": "Deep" }, { - "id": 128001, + "id": 6975, + "logprob": -0.16015625, + "special": false, + "text": " learning" + }, + { + "id": 320, + "logprob": -0.25195312, + "special": false, + "text": " (" + }, + { + "id": 16931, + "logprob": -1.703125, + "special": false, + "text": "DL" + }, + { + "id": 8, "logprob": 0.0, - "special": true, - "text": "<|end_of_text|>" + "special": false, + "text": ")" + }, + { + "id": 374, + "logprob": -1.140625, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": 0.0, + "special": false, + "text": " a" + }, + { + "id": 1207, + "logprob": -1.3125, + "special": false, + "text": " sub" + }, + { + "id": 2630, + "logprob": 0.0, + "special": false, + "text": "field" } ], "top_tokens": null }, - "generated_text": "What is deep learning? \n " + "generated_text": "What is deep learning? 
\nDeep learning (DL) is a subfield" } diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json index 46670819..c7acee46 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json @@ -12,27 +12,27 @@ }, { "id": 3923, - "logprob": -5.6328125, + "logprob": -6.1875, "text": "What" }, { "id": 374, - "logprob": -1.2265625, + "logprob": -0.93359375, "text": " is" }, { "id": 5655, - "logprob": -9.1015625, + "logprob": -9.875, "text": " deep" }, { "id": 6975, - "logprob": -1.8085938, + "logprob": -1.1796875, "text": " learning" }, { "id": 30, - "logprob": -1.0439453, + "logprob": -1.75, "text": "?" } ], @@ -40,68 +40,68 @@ "tokens": [ { "id": 18682, - "logprob": -2.1992188, + "logprob": -1.109375, "special": false, "text": " Deep" }, { "id": 6975, - "logprob": -0.07897949, + "logprob": -0.0047912598, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.27734375, + "logprob": -0.025512695, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.37402344, + "logprob": -0.012145996, "special": false, "text": " a" }, { "id": 27084, - "logprob": -1.4511719, + "logprob": -0.72265625, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.02909851, + "logprob": -0.0005760193, "special": false, "text": " of" }, { "id": 5780, - "logprob": -0.25854492, + "logprob": -0.02722168, "special": false, "text": " machine" }, { "id": 6975, - "logprob": -0.0061798096, + "logprob": -0.00023651123, "special": false, "text": " learning" }, { "id": 430, - "logprob": -1.3046875, + "logprob": -0.17285156, "special": false, "text": " that" }, { - "id": 374, - "logprob": -1.5537109, + "id": 18065, + "logprob": -0.703125, "special": false, - "text": " is" + "text": " involves" } ], "top_tokens": null }, - "generated_text": " Deep learning is a subset of machine learning that is" + "generated_text": " Deep learning is a subset of machine learning that involves" }, { "details": { @@ -116,27 +116,27 @@ }, { "id": 3923, - "logprob": -5.6328125, + "logprob": -6.21875, "text": "What" }, { "id": 374, - "logprob": -1.2265625, + "logprob": -0.95703125, "text": " is" }, { "id": 5655, - "logprob": -9.1015625, + "logprob": -9.9375, "text": " deep" }, { "id": 6975, - "logprob": -1.8085938, + "logprob": -1.1328125, "text": " learning" }, { "id": 30, - "logprob": -1.0439453, + "logprob": -1.75, "text": "?" 
} ], @@ -144,68 +144,68 @@ "tokens": [ { "id": 18682, - "logprob": -2.1992188, + "logprob": -1.1796875, "special": false, "text": " Deep" }, { "id": 6975, - "logprob": -0.07897949, + "logprob": -0.005432129, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.27734375, + "logprob": -0.02758789, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.37402344, + "logprob": -0.013366699, "special": false, "text": " a" }, { "id": 27084, - "logprob": -1.4511719, + "logprob": -0.6953125, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.02909851, + "logprob": -0.0004863739, "special": false, "text": " of" }, { "id": 5780, - "logprob": -0.25854492, + "logprob": -0.02709961, "special": false, "text": " machine" }, { "id": 6975, - "logprob": -0.0061798096, + "logprob": -0.00022506714, "special": false, "text": " learning" }, { "id": 430, - "logprob": -1.3046875, + "logprob": -0.19726562, "special": false, "text": " that" }, { - "id": 374, - "logprob": -1.5537109, + "id": 18065, + "logprob": -0.77734375, "special": false, - "text": " is" + "text": " involves" } ], "top_tokens": null }, - "generated_text": " Deep learning is a subset of machine learning that is" + "generated_text": " Deep learning is a subset of machine learning that involves" }, { "details": { @@ -220,27 +220,27 @@ }, { "id": 3923, - "logprob": -5.6328125, + "logprob": -6.21875, "text": "What" }, { "id": 374, - "logprob": -1.2265625, + "logprob": -0.95703125, "text": " is" }, { "id": 5655, - "logprob": -9.1015625, + "logprob": -9.9375, "text": " deep" }, { "id": 6975, - "logprob": -1.8085938, + "logprob": -1.1328125, "text": " learning" }, { "id": 30, - "logprob": -1.0439453, + "logprob": -1.75, "text": "?" } ], @@ -248,68 +248,68 @@ "tokens": [ { "id": 18682, - "logprob": -2.1992188, + "logprob": -1.1796875, "special": false, "text": " Deep" }, { "id": 6975, - "logprob": -0.07897949, + "logprob": -0.005432129, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.27734375, + "logprob": -0.02758789, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.37402344, + "logprob": -0.013366699, "special": false, "text": " a" }, { "id": 27084, - "logprob": -1.4511719, + "logprob": -0.6953125, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.02909851, + "logprob": -0.0004863739, "special": false, "text": " of" }, { "id": 5780, - "logprob": -0.25854492, + "logprob": -0.02709961, "special": false, "text": " machine" }, { "id": 6975, - "logprob": -0.0061798096, + "logprob": -0.00022506714, "special": false, "text": " learning" }, { "id": 430, - "logprob": -1.3046875, + "logprob": -0.19726562, "special": false, "text": " that" }, { - "id": 374, - "logprob": -1.5537109, + "id": 18065, + "logprob": -0.77734375, "special": false, - "text": " is" + "text": " involves" } ], "top_tokens": null }, - "generated_text": " Deep learning is a subset of machine learning that is" + "generated_text": " Deep learning is a subset of machine learning that involves" }, { "details": { @@ -324,27 +324,27 @@ }, { "id": 3923, - "logprob": -5.6328125, + "logprob": -6.21875, "text": "What" }, { "id": 374, - "logprob": -1.2265625, + "logprob": -0.95703125, "text": " is" }, { "id": 5655, - "logprob": -9.1015625, + "logprob": -9.9375, "text": " deep" }, { "id": 6975, - "logprob": -1.8085938, + "logprob": -1.1328125, "text": " learning" }, { "id": 30, - "logprob": -1.0439453, + "logprob": -1.75, "text": "?" 
} ], @@ -352,67 +352,67 @@ "tokens": [ { "id": 18682, - "logprob": -2.1992188, + "logprob": -1.1796875, "special": false, "text": " Deep" }, { "id": 6975, - "logprob": -0.07897949, + "logprob": -0.005432129, "special": false, "text": " learning" }, { "id": 374, - "logprob": -0.27734375, + "logprob": -0.02758789, "special": false, "text": " is" }, { "id": 264, - "logprob": -0.37402344, + "logprob": -0.013366699, "special": false, "text": " a" }, { "id": 27084, - "logprob": -1.4511719, + "logprob": -0.6953125, "special": false, "text": " subset" }, { "id": 315, - "logprob": -0.02909851, + "logprob": -0.0004863739, "special": false, "text": " of" }, { "id": 5780, - "logprob": -0.25854492, + "logprob": -0.02709961, "special": false, "text": " machine" }, { "id": 6975, - "logprob": -0.0061798096, + "logprob": -0.00022506714, "special": false, "text": " learning" }, { "id": 430, - "logprob": -1.3046875, + "logprob": -0.19726562, "special": false, "text": " that" }, { - "id": 374, - "logprob": -1.5537109, + "id": 18065, + "logprob": -0.77734375, "special": false, - "text": " is" + "text": " involves" } ], "top_tokens": null }, - "generated_text": " Deep learning is a subset of machine learning that is" + "generated_text": " Deep learning is a subset of machine learning that involves" } ] diff --git a/integration-tests/models/test_flash_llama_fp8_kv_cache.py b/integration-tests/models/test_flash_llama_fp8_kv_cache.py index 05e9f0dd..ccd7f78f 100644 --- a/integration-tests/models/test_flash_llama_fp8_kv_cache.py +++ b/integration-tests/models/test_flash_llama_fp8_kv_cache.py @@ -4,7 +4,9 @@ import pytest @pytest.fixture(scope="module") def flash_llama_fp8_kv_cache_handle(launcher): with launcher( - "meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2" + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", + num_shard=2, + kv_cache_dtype="fp8_e4m3fn", ) as handle: yield handle @@ -25,7 +27,7 @@ async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snaps assert ( response.generated_text - == " Deep learning is a subset of machine learning that is" + == " Deep learning is a subset of machine learning that involves" ) assert response.details.generated_tokens == 10 assert response == response_snapshot @@ -69,7 +71,7 @@ async def test_flash_llama_fp8_kv_cache_load( assert len(responses) == 4 assert ( responses[0].generated_text - == " Deep learning is a subset of machine learning that is" + == " Deep learning is a subset of machine learning that involves" ) assert all( [r.generated_text == responses[0].generated_text for r in responses] diff --git a/server/poetry.lock b/server/poetry.lock index 80fe72ba..1293e883 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1215,12 +1215,12 @@ files = [ [[package]] name = "marlin-kernels" -version = "0.2.0" +version = "0.3.0" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:9a5afcf19b0f5917e43353cc19873fb3c4d4d0b924e2a95a37884f9ce208d0bd"}, + {file = "marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:a2086b9e98d22071f52c5b4b4b98b1b4a988565258905173fa74c5a9eddd1a0a"}, ] [package.dependencies] @@ -1228,16 +1228,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" +url = 
"https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.2.0" +version = "0.3.0" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:1e64fcc7ebadfaffa60091ee9201ae3daaf5c1be3be60c8c054143a3dcb72d5d"}, + {file = "marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:f39a6946d8247629446ec170832d832c7038c363f1d8803211fe67249c2d804d"}, ] [package.dependencies] @@ -1245,16 +1245,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.2.0" +version = "0.3.0" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:e75f3ce9b1c13a4ed43a380d88e1d34d297259452db037ec1973ec33dc2eb78e"}, + {file = "marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:07fd869d5289777fa866107dae676523e18b1f6ba4afce79946ddc58a6870169"}, ] [package.dependencies] @@ -1262,16 +1262,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.2.0" +version = "0.3.0" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:2f99a27f70b391887ee6adffeeee7c3f4df7fac37393f9fb16d4cace2b3f6457"}, + {file = "marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:0dedaa418225d490a5f1d8f85dbc75e439a8c43a8870e4ef32945bf61672d7dc"}, ] [package.dependencies] @@ -1279,7 +1279,7 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" [[package]] name = "mdurl" diff --git a/server/pyproject.toml b/server/pyproject.toml index 6ea4718d..d08d0b8f 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0" numpy = "^1.26" marlin-kernels = [ - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, - { url = 
"https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] moe-kernels = [ { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index b1d7b864..ebe32042 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -28,10 +28,11 @@ else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") # KVCache needs `reshape_and_cache`, so ensure that it is defined already. -from .kv_cache import KVCache +from .kv_cache import KVCache, get_kv_scales __all__ = [ "attention", + "get_kv_scales", "paged_attention", "SUPPORTS_WINDOWING", "KVCache", diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 08326c82..d705afb0 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,5 +1,5 @@ import torch -from text_generation_server.layers.attention.kv_cache import KVCache +from text_generation_server.layers.attention.kv_cache import KVCache, KVScales from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import ( ATTENTION, @@ -8,6 +8,7 @@ from text_generation_server.models.globals import ( from text_generation_server.layers.attention import Seqlen from typing import Optional + major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 _PARTITION_SIZE = 512 @@ -21,6 +22,8 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + *, + kv_scales: KVScales, softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -46,6 +49,8 @@ def paged_attention( num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + can_scale = kv_cache.can_scale(kv_scales) + # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. 
Also, if the number of @@ -60,6 +65,8 @@ def paged_attention( paged_kv_cache=(kv_cache.key, kv_cache.value), logits_soft_cap=softcap, sm_scale=softmax_scale, + k_scale=kv_scales.key_scale_cpu if can_scale else 1.0, + v_scale=kv_scales.value_scale_cpu if can_scale else 1.0, ) elif ATTENTION == "flashdecoding": max_q = 1 @@ -205,6 +212,7 @@ def attention( key: torch.Tensor, value: torch.Tensor, kv_cache: KVCache, + kv_scales: KVScales, seqlen: Seqlen, block_tables: torch.Tensor, softmax_scale: float, @@ -212,6 +220,8 @@ def attention( causal: bool = True, softcap: Optional[float] = None, ): + can_scale = kv_cache.can_scale(kv_scales) + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flashinfer import ( prefill_with_paged_kv_state, @@ -228,6 +238,8 @@ def attention( logits_soft_cap=softcap, sm_scale=softmax_scale, window_left=window_size_left, + k_scale=kv_scales.key_scale_cpu if can_scale else 1.0, + v_scale=kv_scales.value_scale_cpu if can_scale else 1.0, ) # If we are using flashdecoding or paged, we always use flash-attn for diff --git a/server/text_generation_server/layers/attention/flashinfer.py b/server/text_generation_server/layers/attention/flashinfer.py index d603c6f5..26a72d9b 100644 --- a/server/text_generation_server/layers/attention/flashinfer.py +++ b/server/text_generation_server/layers/attention/flashinfer.py @@ -204,6 +204,7 @@ def use_decode_state( num_kv_heads: int, head_size: int, page_size: int, + kv_cache_dtype: torch.dtype, dtype: torch.dtype, window_left: int, ): @@ -240,7 +241,7 @@ def use_decode_state( num_kv_heads=num_kv_heads, head_dim=head_size, page_size=page_size, - data_type=dtype, + data_type=kv_cache_dtype, q_data_type=dtype, window_left=window_left, ) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index e76bb1f4..677f3f56 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -1,6 +1,6 @@ import intel_extension_for_pytorch as ipex import torch -from text_generation_server.layers.attention.kv_cache import KVCache +from text_generation_server.layers.attention.kv_cache import KVCache, KVScales from text_generation_server.models.flash_causal_lm import BLOCK_SIZE from text_generation_server.layers.attention import Seqlen from typing import Optional @@ -14,6 +14,7 @@ def attention( key: torch.Tensor, value: torch.Tensor, kv_cache: KVCache, + kv_scales: KVScales, seqlen: Seqlen, block_tables: torch.Tensor, softmax_scale: float, @@ -55,6 +56,8 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + *, + kv_scales: KVScales, softcap: Optional[float] = None, ): if softcap is not None: diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py index d64302c6..9d739da5 100644 --- a/server/text_generation_server/layers/attention/kv_cache.py +++ b/server/text_generation_server/layers/attention/kv_cache.py @@ -1,8 +1,38 @@ from typing import Tuple +from dataclasses import dataclass, field +from loguru import logger import torch + +from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weights + + +@dataclass +class KVScales: + """ + Key-value scales for FP8 
KV cache. + + This data class stores key and value scales both as a GPU tensor and + as a GPU float. This inconvenience is necessary because some functions + (e.g. scaling kernels) take scales as a GPU tensor, whereas others + (e.g. flashinfer) take scales as a CPU scalar. + """ + + key_scale: torch.Tensor + value_scale: torch.Tensor + key_scale_cpu: float = field(init=False) + value_scale_cpu: float = field(init=False) + + def __post_init__(self): + if self.key_scale.numel() != 1 or self.value_scale.numel() != 1: + raise ValueError("Key and value scales must be scalar tensors.") + + self.key_scale_cpu = self.key_scale.item() + self.value_scale_cpu = self.value_scale.item() class KVCache: @@ -76,6 +106,33 @@ class KVCache: ), ) + def can_scale(self, kv_scales: KVScales) -> bool: + """Check if the cache can be scaled by the given scales.""" + if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0: + return False + elif ( + self.dtype == torch.float8_e4m3fn + and ATTENTION == "flashinfer" + and SYSTEM == "cuda" + ): + log_once( + logger.info, + "Using FP8 KV cache scales", + ) + return True + else: + # We have scales, but not the correct FP8 cache type, so warn once. + log_once( + logger.info, + "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported", + ) + return False + + @property + def dtype(self): + """Get the data type of the cache.""" + return self.kv_cache[0].dtype + @property def key(self): """Get the key cache.""" @@ -94,17 +151,33 @@ class KVCache: key: torch.Tensor, value: torch.Tensor, slots: torch.Tensor, + kv_scales: KVScales, ): """Store the key and value at the given slots.""" key_cache = self.kv_cache[0] value_cache = self.kv_cache[1] + if self.can_scale(kv_scales): + if kv_scales.key_scale_cpu != 1.0: + key = fp8_quantize( + key.float(), + scale=kv_scales.key_scale, + qdtype=self.dtype, + scalar=True, + )[0] + if kv_scales.value_scale_cpu != 1.0: + value = fp8_quantize( + value.float(), + scale=kv_scales.value_scale, + qdtype=self.dtype, + scalar=True, + )[0] + if ATTENTION in {"flashdecoding", "flashinfer"}: - # TODO: add scale key = key.to(key_cache.dtype) value = value.to(value_cache.dtype) - if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}: + if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so # put as raw data instead. key_cache = key_cache.view(torch.uint8) @@ -151,5 +224,23 @@ def paged_reshape_and_cache( ) else: raise NotImplementedError( - f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supportedattention" + f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported" ) + + +def get_kv_scales(weights: Weights, prefix: str) -> KVScales: + """Load KV cache scales.""" + + key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device) + value_scale = key_scale + if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor( + f"{prefix}.v_scale" + ): + key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float() + value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float() + elif weights.has_tensor(f"{prefix}.kv_scale"): + # Fall back to older more coarse-grained scale when available. 
+ key_scale = weights.get_tensor(f"{prefix}.kv_scale").float() + value_scale = key_scale + + return KVScales(key_scale=key_scale, value_scale=value_scale) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 47bf5539..ea11c2c2 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,7 +1,7 @@ import os from typing import Optional import torch -from text_generation_server.layers.attention.kv_cache import KVCache +from text_generation_server.layers.attention.kv_cache import KVCache, KVScales from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import Seqlen from text_generation_server.utils.log import log_master @@ -36,6 +36,8 @@ def paged_attention( block_tables: torch.Tensor, seqlen: Seqlen, max_s: int, + *, + kv_scales: KVScales, softcap: Optional[float] = None, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -210,6 +212,7 @@ def attention( key: torch.Tensor, value: torch.Tensor, kv_cache: KVCache, + kv_scales: KVScales, seqlen: Seqlen, block_tables: torch.Tensor, softmax_scale: float, diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index 18a40afa..a58c7f7b 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -26,6 +26,12 @@ def is_fbgemm_gpu_available(): return False +try: + import marlin_kernels +except ImportError: + marlin_kernels = None + + if is_fbgemm_gpu_available(): if SYSTEM == "cuda": major, _ = torch.cuda.get_device_capability() @@ -94,6 +100,17 @@ def fp8_quantize( ) return qweight, scale + if marlin_kernels is not None: + shape = weight.shape + qweight, scale = marlin_kernels.scaled_fp8_quant( + weight.reshape(-1, shape[-1]), + dtype=qdtype, + scale=scale, + scale_ub=scale_upper_bound, + ) + + return qweight.reshape(shape), scale + # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 4eee5c20..68719106 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( attention, Seqlen, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, @@ -227,6 +228,7 @@ class FlashCohereAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.use_qk_norm = config.use_qk_norm if self.use_qk_norm: @@ -289,7 +291,12 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) - kv_cache.store(key=key, value=value, slots=slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -299,6 +306,7 @@ class FlashCohereAttention(torch.nn.Module): key=key, value=value, kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, 
block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -313,6 +321,7 @@ class FlashCohereAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 4ee67741..f70bff4f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -20,6 +20,7 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.utils.import_utils import SYSTEM if SYSTEM != "ipex": @@ -288,6 +289,7 @@ class DbrxAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -328,7 +330,12 @@ class DbrxAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -338,6 +345,7 @@ class DbrxAttention(torch.nn.Module): key=kv[:, 0], value=kv[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -352,6 +360,7 @@ class DbrxAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 97b3ea96..906a83a4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -34,6 +34,7 @@ from text_generation_server.layers.attention import ( attention, paged_attention, ) +from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale @@ -230,6 +231,8 @@ class DeepseekV2Attention(torch.nn.Module): ), ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.kv_a_layernorm = FastRMSNorm.load( prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps ) @@ -258,7 +261,7 @@ class DeepseekV2Attention(torch.nn.Module): cos: torch.Tensor, sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], + kv_cache: KVCache, block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, @@ -319,7 +322,12 @@ class DeepseekV2Attention(torch.nn.Module): value, (0, self.head_pad_size - self.value_head_size), value=0 ) - kv_cache.store(key=key, value=value, slots=slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -329,6 +337,7 @@ class 
DeepseekV2Attention(torch.nn.Module): key=key, value=value, kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -343,6 +352,7 @@ class DeepseekV2Attention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) # Remove padding. diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index c962a2af..ebf1b80e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -39,6 +39,7 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -206,6 +207,7 @@ class FlashGemma2Attention(torch.nn.Module): ], process_group=weights.process_group, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") o_proj = TensorParallelRowLinear.load( config, @@ -251,7 +253,12 @@ class FlashGemma2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -261,6 +268,7 @@ class FlashGemma2Attention(torch.nn.Module): key=kv[:, 0], value=kv[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -278,6 +286,7 @@ class FlashGemma2Attention(torch.nn.Module): seqlen, max_s, softcap=self.softcap, + kv_scales=self.kv_scales, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index b127f284..ad3be80e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -37,6 +37,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -185,6 +186,7 @@ class FlashGemmaAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -222,7 +224,12 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -232,6 +239,7 @@ class FlashGemmaAttention(torch.nn.Module): key=kv[:, 0], value=kv[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -247,6 +255,7 @@ class FlashGemmaAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return 
self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 2d005734..906b34c1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -36,6 +36,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales def load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -193,6 +194,7 @@ class FlashGPT2Attention(torch.nn.Module): head_size=self.head_size, num_heads=self.num_heads, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = load_row( config, @@ -222,7 +224,12 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) - kv_cache.store(key=key, value=value, slots=slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -232,6 +239,7 @@ class FlashGPT2Attention(torch.nn.Module): key=key, value=value, kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -246,6 +254,7 @@ class FlashGPT2Attention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index 2eef1ded..692f8ca3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -24,6 +24,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, @@ -138,6 +139,7 @@ class FlashGPTJAttention(torch.nn.Module): prefix=prefix, weights=weights, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = load_row( config, @@ -184,7 +186,12 @@ class FlashGPTJAttention(torch.nn.Module): else: self.rotary_emb(query, key, cos, sin) - kv_cache.store(key=key, value=value, slots=slots) + kv_cache.store( + key=key, + value=value, + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -194,6 +201,7 @@ class FlashGPTJAttention(torch.nn.Module): key=key, value=value, kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -208,6 +216,7 @@ class FlashGPTJAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 20841aeb..b26dd484 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -27,7 +27,10 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from text_generation_server.layers.attention import KVCache +from text_generation_server.layers.attention import ( + KVCache, + get_kv_scales, +) from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( @@ -179,6 +182,8 @@ class FlashLlamaAttention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights, index) self.index = index + self.kv_scales = get_kv_scales(weights, f"{prefix}") + o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", @@ -224,7 +229,12 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -233,6 +243,7 @@ class FlashLlamaAttention(torch.nn.Module): query=query, key=kv[:, 0], value=kv[:, 1], + kv_scales=self.kv_scales, kv_cache=kv_cache, seqlen=seqlen, block_tables=block_tables, @@ -248,6 +259,7 @@ class FlashLlamaAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 7bad429c..c66c732f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -26,6 +26,7 @@ from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, @@ -158,6 +159,7 @@ class MistralAttention(torch.nn.Module): ], process_group=weights.process_group, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") o_proj = TensorParallelRowLinear.load( config, @@ -208,7 +210,12 @@ class MistralAttention(torch.nn.Module): else: kv_to_cache = kv - kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -218,6 +225,7 @@ class MistralAttention(torch.nn.Module): key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -233,6 +241,7 @@ class MistralAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 712b7bc4..a45dd1e6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -38,6 +38,7 @@ from 
text_generation_server.layers.attention import ( attention, paged_attention, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding @@ -213,6 +214,7 @@ class MixtralAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -256,7 +258,12 @@ class MixtralAttention(torch.nn.Module): else: kv_to_cache = kv - kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -266,6 +273,7 @@ class MixtralAttention(torch.nn.Module): key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -281,6 +289,7 @@ class MixtralAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 2ce69d8e..2301b63c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -38,6 +38,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -130,6 +131,7 @@ class FlashNeoxAttention(torch.nn.Module): head_size=self.head_size, hidden_size=self.hidden_size, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=True ) @@ -163,7 +165,12 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) - kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots) + kv_cache.store( + key=qkv[:, 1], + value=qkv[:, 2], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -173,6 +180,7 @@ class FlashNeoxAttention(torch.nn.Module): key=qkv[:, 1], value=qkv[:, 2], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -187,6 +195,7 @@ class FlashNeoxAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 62d524c9..7382a7cb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -18,6 +18,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from 
text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -137,6 +138,7 @@ class FlashPhiAttention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") # in llama the dense layer is called "o_proj" and has bias=False self.dense = TensorParallelRowLinear.load( @@ -186,7 +188,12 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -194,6 +201,7 @@ class FlashPhiAttention(torch.nn.Module): query=query, key=kv[:, 0], value=kv[:, 1], + kv_scales=self.kv_scales, kv_cache=kv_cache, seqlen=seqlen, block_tables=block_tables, @@ -209,6 +217,7 @@ class FlashPhiAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 905dd98f..ab2a177d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -16,6 +16,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, SpeculativeHead, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -84,6 +85,8 @@ class Qwen2Attention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") + self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.o_proj", @@ -126,7 +129,12 @@ class Qwen2Attention(torch.nn.Module): else: kv_to_cache = kv - kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -136,6 +144,7 @@ class Qwen2Attention(torch.nn.Module): key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -151,6 +160,7 @@ class Qwen2Attention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 8085ff89..2dcd1bf3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -12,6 +12,7 @@ from text_generation_server.layers import ( TensorParallelRowLinear, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( @@ -158,6 +159,7 @@ class FlashRWAttention(torch.nn.Module): weights=weights, 
bias=config.bias, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) @@ -198,7 +200,12 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) + kv_cache.store( + key=kv[:, 0], + value=kv[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -208,6 +215,7 @@ class FlashRWAttention(torch.nn.Module): key=kv[:, 0], value=kv[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -222,6 +230,7 @@ class FlashRWAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -276,6 +285,7 @@ class FlashRWLargeAttention(torch.nn.Module): weights=weights, bias=config.bias, ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.dense = load_row( config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias ) @@ -311,7 +321,10 @@ class FlashRWLargeAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) kv_cache.store( - key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots + key=kv[:, :, 0].contiguous(), + value=kv[:, :, 1].contiguous(), + slots=slots, + kv_scales=self.kv_scales, ) # Prefill @@ -322,6 +335,7 @@ class FlashRWLargeAttention(torch.nn.Module): key=kv[:, :, 0], value=kv[:, :, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -336,6 +350,7 @@ class FlashRWLargeAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.dense( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 52119b64..ed053eb6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -17,6 +17,7 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, @@ -257,6 +258,7 @@ class FlashMQAttention(torch.nn.Module): self.c_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True ) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.kv_head_mapping = torch.zeros( self.num_heads, dtype=torch.int32, device=weights.device ) @@ -282,7 +284,12 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots) + kv_cache.store( + key=key_value[:, 0], + value=key_value[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -292,6 +299,7 @@ class FlashMQAttention(torch.nn.Module): key=key_value[:, 0], value=key_value[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -306,6 +314,7 @@ class 
FlashMQAttention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index fe339aee..c793982d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -38,6 +38,7 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) +from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import ( FastLayerNorm, FastRMSNorm, @@ -188,6 +189,7 @@ class Starcoder2Attention(torch.nn.Module): ) self.query_key_value = load_attention(config, prefix, weights) + self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = TensorParallelRowLinear.load( config, @@ -231,7 +233,12 @@ class Starcoder2Attention(torch.nn.Module): else: kv_to_cache = kv - kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) + kv_cache.store( + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + slots=slots, + kv_scales=self.kv_scales, + ) # Prefill if cu_seqlen_prefill is not None: @@ -241,6 +248,7 @@ class Starcoder2Attention(torch.nn.Module): key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], kv_cache=kv_cache, + kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, @@ -256,6 +264,7 @@ class Starcoder2Attention(torch.nn.Module): block_tables, seqlen, max_s, + kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index b1270b44..b931671c 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -2283,6 +2283,7 @@ class FlashCausalLM(Model): num_kv_heads=self.num_kv_heads, head_size=self.head_size, page_size=BLOCK_SIZE, + kv_cache_dtype=self.kv_cache_dtype, dtype=self.dtype, window_left=self.sliding_window, ) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 548591e5..aae64acf 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -207,7 +207,9 @@ class Weights: def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): + def get_tensor( + self, tensor_name: str, to_device: bool = True, to_dtype: bool = True + ) -> torch.Tensor: filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name)
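
For reference, the KV-cache scaling pieces introduced in `kv_cache.py` are spread across several hunks above. The standalone sketch below pulls them together so the control flow is easier to follow. It is illustrative only, not part of the patch: the plain-dict lookup in `load_kv_scales` and the free function `can_scale` stand in for the real `Weights` object and the `KVCache.can_scale` method, and only `torch` is required to run it.

```python
from dataclasses import dataclass, field
from typing import Dict

import torch


@dataclass
class KVScales:
    """Scalar key/value scales, kept both as tensors (for scaling kernels)
    and as plain CPU floats (for flashinfer), mirroring the class in the diff."""

    key_scale: torch.Tensor
    value_scale: torch.Tensor
    key_scale_cpu: float = field(init=False)
    value_scale_cpu: float = field(init=False)

    def __post_init__(self):
        if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
            raise ValueError("Key and value scales must be scalar tensors.")
        self.key_scale_cpu = self.key_scale.item()
        self.value_scale_cpu = self.value_scale.item()


def load_kv_scales(tensors: Dict[str, torch.Tensor], prefix: str) -> KVScales:
    """Same precedence as get_kv_scales: per-tensor k_scale/v_scale first,
    then the older coarse-grained kv_scale, otherwise unit scales."""
    key_scale = torch.tensor(1.0, dtype=torch.float32)
    value_scale = key_scale
    if f"{prefix}.k_scale" in tensors and f"{prefix}.v_scale" in tensors:
        key_scale = tensors[f"{prefix}.k_scale"].float()
        value_scale = tensors[f"{prefix}.v_scale"].float()
    elif f"{prefix}.kv_scale" in tensors:
        # Fall back to the older coarse-grained scale when available.
        key_scale = tensors[f"{prefix}.kv_scale"].float()
        value_scale = key_scale
    return KVScales(key_scale=key_scale, value_scale=value_scale)


def can_scale(cache_dtype: torch.dtype, attention: str, system: str,
              scales: KVScales) -> bool:
    """Same gating as KVCache.can_scale: unit scales mean 'no scaling', and
    non-unit scales are only honoured for an e4m3 cache on flashinfer/CUDA."""
    if scales.key_scale_cpu == 1.0 and scales.value_scale_cpu == 1.0:
        return False
    return (
        cache_dtype == torch.float8_e4m3fn
        and attention == "flashinfer"
        and system == "cuda"
    )


# Example: a checkpoint that only ships the older coarse-grained scale.
scales = load_kv_scales(
    {"model.layers.0.self_attn.kv_scale": torch.tensor(0.5)},
    "model.layers.0.self_attn",
)
print(scales.key_scale_cpu, scales.value_scale_cpu)                   # 0.5 0.5
print(can_scale(torch.float8_e4m3fn, "flashinfer", "cuda", scales))   # True
print(can_scale(torch.float8_e5m2, "flashinfer", "cuda", scales))     # False
```

The gating in `can_scale` also lines up with the integration-test change above: an `fp8_e5m2` cache never takes the scaled path, so the fixture now loads an FP8-KV checkpoint and passes `kv_cache_dtype="fp8_e4m3fn"` instead.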