Add support for FP8 KV cache scales (#2628)

* Add support for FP8 KV cache scales

Since FP8 only has a limited dynamic range, we can scale keys/values
before storing them in the cache (and unscale them in attention). To
avoid rescaling the cache as the absmax values change, good scales are
usually determined per layer using calibration data and stored in the
checkpoint.
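
To make the mechanism concrete, here is a minimal, illustrative sketch of quantizing with a fixed per-layer scale and dequantizing at attention time. The helper names are hypothetical; the actual work happens in `KVCache.store` and the attention kernels:

```python
import torch

# Illustrative sketch only: a calibrated per-layer scalar `k_scale` maps the
# observed key range into the representable range of float8_e4m3fn.
def quantize_for_cache(key: torch.Tensor, k_scale: float) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Divide by the calibrated scale, then clamp to the FP8 range.
    scaled = (key.float() / k_scale).clamp(finfo.min, finfo.max)
    return scaled.to(torch.float8_e4m3fn)

def dequantize_from_cache(cached_key: torch.Tensor, k_scale: float) -> torch.Tensor:
    # Attention kernels apply the same scale in the other direction.
    return cached_key.float() * k_scale
```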

This change adds support for using key-value scales and loading them
from checkpoints in the two most common formats:

- Separate per-layer `k_scale` and `v_scale` scalars.
- Per-layer `kv_scale` scalar (older format).

Currently, scales are only used with a `float8_e4m3fn` cache.
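
A condensed sketch of the loading logic, assuming the server's `Weights` helper; the full version is the new `get_kv_scales` in `kv_cache.py` further down in this diff, which additionally wraps the two scalars in the new `KVScales` dataclass:

```python
import torch

def load_kv_scales(weights, prefix: str):
    # Default to 1.0 (i.e. no scaling) when the checkpoint carries no scales.
    key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
    value_scale = key_scale

    if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(f"{prefix}.v_scale"):
        # Preferred format: separate per-layer key and value scalars.
        key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
        value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
    elif weights.has_tensor(f"{prefix}.kv_scale"):
        # Older format: one coarser-grained scalar shared by keys and values.
        key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
        value_scale = key_scale

    return key_scale, value_scale
```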

Besides adding support for key/value scales, the `fp8_quantize` function
is also extended to support quantization with a kernel vendored from
vLLM. This is slightly faster than the PyTorch implementation and also
performs scaling in FP32, potentially improving accuracy.
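
The new fast path looks roughly like the following, condensed from the change to `fp8_quantize` in this diff; `scaled_fp8_quant` comes from the vendored `marlin_kernels` package:

```python
import torch

try:
    import marlin_kernels  # vendored vLLM FP8 kernels (optional dependency)
except ImportError:
    marlin_kernels = None

def fp8_quantize_fast(weight, scale=None, qdtype=torch.float8_e4m3fn, scale_upper_bound=None):
    """Condensed sketch: only the vendored-kernel branch is shown; the real
    `fp8_quantize` falls back to a pure-PyTorch implementation otherwise."""
    assert marlin_kernels is not None
    shape = weight.shape
    # The kernel quantizes a row-major 2D view and performs scaling in FP32.
    qweight, scale = marlin_kernels.scaled_fp8_quant(
        weight.reshape(-1, shape[-1]),
        dtype=qdtype,
        scale=scale,                 # reuse a checkpoint scale (e.g. a KV scale)...
        scale_ub=scale_upper_bound,  # ...or derive one dynamically, optionally clamped
    )
    return qweight.reshape(shape), scale
```

In the cache path, `KVCache.store` calls `fp8_quantize(key.float(), scale=..., qdtype=..., scalar=True)` with the per-layer scales loaded above, so keys and values are quantized in FP32 before being written to the FP8 cache.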

* Update the FP8 KV cache test to use a checkpoint with scales

* `can_scale`: check that the attention backend is flashinfer
Daniël de Kok, 2024-10-24 16:36:18 +02:00, committed by GitHub
commit eab07f746c (parent 14a0df3a38)
33 changed files with 486 additions and 155 deletions


@ -978,15 +978,16 @@
"nixpkgs": "nixpkgs_6" "nixpkgs": "nixpkgs_6"
}, },
"locked": { "locked": {
"lastModified": 1728381423, "lastModified": 1729531056,
"narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=", "narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
"owner": "huggingface", "owner": "huggingface",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e", "rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "huggingface", "owner": "huggingface",
"ref": "marlin-kernels-0.3.0",
"repo": "text-generation-inference-nix", "repo": "text-generation-inference-nix",
"type": "github" "type": "github"
} }


@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
}; };
nix-filter.url = "github:numtide/nix-filter"; nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix"; tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
nixpkgs.follows = "tgi-nix/nixpkgs"; nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
rust-overlay = { rust-overlay = {


@ -11,27 +11,27 @@
}, },
{ {
"id": 3923, "id": 3923,
"logprob": -5.6328125, "logprob": -6.1875,
"text": "What" "text": "What"
}, },
{ {
"id": 374, "id": 374,
"logprob": -1.2265625, "logprob": -0.93359375,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -9.1015625, "logprob": -9.875,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -1.8085938, "logprob": -1.1796875,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -1.0439453, "logprob": -1.75,
"text": "?" "text": "?"
} }
], ],
@ -39,66 +39,66 @@
"tokens": [ "tokens": [
{ {
"id": 18682, "id": 18682,
"logprob": -2.1992188, "logprob": -1.109375,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.079956055, "logprob": -0.005432129,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.2763672, "logprob": -0.028808594,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37548828, "logprob": -0.013671875,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 27084, "id": 27084,
"logprob": -1.4628906, "logprob": -0.69921875,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.02885437, "logprob": -0.0005874634,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5780, "id": 5780,
"logprob": -0.2565918, "logprob": -0.026855469,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.0063438416, "logprob": -0.00020885468,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 430, "id": 430,
"logprob": -1.3056641, "logprob": -0.17773438,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 374, "id": 18065,
"logprob": -1.6035156, "logprob": -0.703125,
"special": false, "special": false,
"text": " is" "text": " involves"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " Deep learning is a subset of machine learning that is" "generated_text": " Deep learning is a subset of machine learning that involves"
} }


@ -1,8 +1,8 @@
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
"finish_reason": "eos_token", "finish_reason": "length",
"generated_tokens": 3, "generated_tokens": 10,
"prefill": [ "prefill": [
{ {
"id": 128000, "id": 128000,
@ -11,22 +11,22 @@
}, },
{ {
"id": 374, "id": 374,
"logprob": -22.96875, "logprob": -18.0,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -10.71875, "logprob": -11.75,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -2.6992188, "logprob": -2.0625,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -4.8398438, "logprob": -6.0,
"text": "?" "text": "?"
} }
], ],
@ -34,24 +34,66 @@
"tokens": [ "tokens": [
{ {
"id": 720, "id": 720,
"logprob": -0.4411621, "logprob": 0.0,
"special": false, "special": false,
"text": " \n" "text": " \n"
}, },
{ {
"id": 220, "id": 34564,
"logprob": -0.35864258, "logprob": -0.11279297,
"special": false, "special": false,
"text": " " "text": "Deep"
}, },
{ {
"id": 128001, "id": 6975,
"logprob": -0.16015625,
"special": false,
"text": " learning"
},
{
"id": 320,
"logprob": -0.25195312,
"special": false,
"text": " ("
},
{
"id": 16931,
"logprob": -1.703125,
"special": false,
"text": "DL"
},
{
"id": 8,
"logprob": 0.0, "logprob": 0.0,
"special": true, "special": false,
"text": "<|end_of_text|>" "text": ")"
},
{
"id": 374,
"logprob": -1.140625,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 1207,
"logprob": -1.3125,
"special": false,
"text": " sub"
},
{
"id": 2630,
"logprob": 0.0,
"special": false,
"text": "field"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "What is deep learning? \n " "generated_text": "What is deep learning? \nDeep learning (DL) is a subfield"
} }


@ -12,27 +12,27 @@
}, },
{ {
"id": 3923, "id": 3923,
"logprob": -5.6328125, "logprob": -6.1875,
"text": "What" "text": "What"
}, },
{ {
"id": 374, "id": 374,
"logprob": -1.2265625, "logprob": -0.93359375,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -9.1015625, "logprob": -9.875,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -1.8085938, "logprob": -1.1796875,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -1.0439453, "logprob": -1.75,
"text": "?" "text": "?"
} }
], ],
@ -40,68 +40,68 @@
"tokens": [ "tokens": [
{ {
"id": 18682, "id": 18682,
"logprob": -2.1992188, "logprob": -1.109375,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.07897949, "logprob": -0.0047912598,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.27734375, "logprob": -0.025512695,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37402344, "logprob": -0.012145996,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 27084, "id": 27084,
"logprob": -1.4511719, "logprob": -0.72265625,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.02909851, "logprob": -0.0005760193,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5780, "id": 5780,
"logprob": -0.25854492, "logprob": -0.02722168,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.0061798096, "logprob": -0.00023651123,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 430, "id": 430,
"logprob": -1.3046875, "logprob": -0.17285156,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 374, "id": 18065,
"logprob": -1.5537109, "logprob": -0.703125,
"special": false, "special": false,
"text": " is" "text": " involves"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " Deep learning is a subset of machine learning that is" "generated_text": " Deep learning is a subset of machine learning that involves"
}, },
{ {
"details": { "details": {
@ -116,27 +116,27 @@
}, },
{ {
"id": 3923, "id": 3923,
"logprob": -5.6328125, "logprob": -6.21875,
"text": "What" "text": "What"
}, },
{ {
"id": 374, "id": 374,
"logprob": -1.2265625, "logprob": -0.95703125,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -9.1015625, "logprob": -9.9375,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -1.8085938, "logprob": -1.1328125,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -1.0439453, "logprob": -1.75,
"text": "?" "text": "?"
} }
], ],
@ -144,68 +144,68 @@
"tokens": [ "tokens": [
{ {
"id": 18682, "id": 18682,
"logprob": -2.1992188, "logprob": -1.1796875,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.07897949, "logprob": -0.005432129,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.27734375, "logprob": -0.02758789,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37402344, "logprob": -0.013366699,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 27084, "id": 27084,
"logprob": -1.4511719, "logprob": -0.6953125,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.02909851, "logprob": -0.0004863739,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5780, "id": 5780,
"logprob": -0.25854492, "logprob": -0.02709961,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.0061798096, "logprob": -0.00022506714,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 430, "id": 430,
"logprob": -1.3046875, "logprob": -0.19726562,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 374, "id": 18065,
"logprob": -1.5537109, "logprob": -0.77734375,
"special": false, "special": false,
"text": " is" "text": " involves"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " Deep learning is a subset of machine learning that is" "generated_text": " Deep learning is a subset of machine learning that involves"
}, },
{ {
"details": { "details": {
@ -220,27 +220,27 @@
}, },
{ {
"id": 3923, "id": 3923,
"logprob": -5.6328125, "logprob": -6.21875,
"text": "What" "text": "What"
}, },
{ {
"id": 374, "id": 374,
"logprob": -1.2265625, "logprob": -0.95703125,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -9.1015625, "logprob": -9.9375,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -1.8085938, "logprob": -1.1328125,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -1.0439453, "logprob": -1.75,
"text": "?" "text": "?"
} }
], ],
@ -248,68 +248,68 @@
"tokens": [ "tokens": [
{ {
"id": 18682, "id": 18682,
"logprob": -2.1992188, "logprob": -1.1796875,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.07897949, "logprob": -0.005432129,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.27734375, "logprob": -0.02758789,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37402344, "logprob": -0.013366699,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 27084, "id": 27084,
"logprob": -1.4511719, "logprob": -0.6953125,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.02909851, "logprob": -0.0004863739,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5780, "id": 5780,
"logprob": -0.25854492, "logprob": -0.02709961,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.0061798096, "logprob": -0.00022506714,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 430, "id": 430,
"logprob": -1.3046875, "logprob": -0.19726562,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 374, "id": 18065,
"logprob": -1.5537109, "logprob": -0.77734375,
"special": false, "special": false,
"text": " is" "text": " involves"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " Deep learning is a subset of machine learning that is" "generated_text": " Deep learning is a subset of machine learning that involves"
}, },
{ {
"details": { "details": {
@ -324,27 +324,27 @@
}, },
{ {
"id": 3923, "id": 3923,
"logprob": -5.6328125, "logprob": -6.21875,
"text": "What" "text": "What"
}, },
{ {
"id": 374, "id": 374,
"logprob": -1.2265625, "logprob": -0.95703125,
"text": " is" "text": " is"
}, },
{ {
"id": 5655, "id": 5655,
"logprob": -9.1015625, "logprob": -9.9375,
"text": " deep" "text": " deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -1.8085938, "logprob": -1.1328125,
"text": " learning" "text": " learning"
}, },
{ {
"id": 30, "id": 30,
"logprob": -1.0439453, "logprob": -1.75,
"text": "?" "text": "?"
} }
], ],
@ -352,67 +352,67 @@
"tokens": [ "tokens": [
{ {
"id": 18682, "id": 18682,
"logprob": -2.1992188, "logprob": -1.1796875,
"special": false, "special": false,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.07897949, "logprob": -0.005432129,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 374, "id": 374,
"logprob": -0.27734375, "logprob": -0.02758789,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 264, "id": 264,
"logprob": -0.37402344, "logprob": -0.013366699,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 27084, "id": 27084,
"logprob": -1.4511719, "logprob": -0.6953125,
"special": false, "special": false,
"text": " subset" "text": " subset"
}, },
{ {
"id": 315, "id": 315,
"logprob": -0.02909851, "logprob": -0.0004863739,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5780, "id": 5780,
"logprob": -0.25854492, "logprob": -0.02709961,
"special": false, "special": false,
"text": " machine" "text": " machine"
}, },
{ {
"id": 6975, "id": 6975,
"logprob": -0.0061798096, "logprob": -0.00022506714,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 430, "id": 430,
"logprob": -1.3046875, "logprob": -0.19726562,
"special": false, "special": false,
"text": " that" "text": " that"
}, },
{ {
"id": 374, "id": 18065,
"logprob": -1.5537109, "logprob": -0.77734375,
"special": false, "special": false,
"text": " is" "text": " involves"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": " Deep learning is a subset of machine learning that is" "generated_text": " Deep learning is a subset of machine learning that involves"
} }
] ]

View File

@ -4,7 +4,9 @@ import pytest
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher): def flash_llama_fp8_kv_cache_handle(launcher):
with launcher( with launcher(
"meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2" "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
num_shard=2,
kv_cache_dtype="fp8_e4m3fn",
) as handle: ) as handle:
yield handle yield handle
@ -25,7 +27,7 @@ async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snaps
assert ( assert (
response.generated_text response.generated_text
== " Deep learning is a subset of machine learning that is" == " Deep learning is a subset of machine learning that involves"
) )
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert response == response_snapshot assert response == response_snapshot
@ -69,7 +71,7 @@ async def test_flash_llama_fp8_kv_cache_load(
assert len(responses) == 4 assert len(responses) == 4
assert ( assert (
responses[0].generated_text responses[0].generated_text
== " Deep learning is a subset of machine learning that is" == " Deep learning is a subset of machine learning that involves"
) )
assert all( assert all(
[r.generated_text == responses[0].generated_text for r in responses] [r.generated_text == responses[0].generated_text for r in responses]

server/poetry.lock (generated, 24 changed lines)

@ -1215,12 +1215,12 @@ files = [
[[package]] [[package]]
name = "marlin-kernels" name = "marlin-kernels"
version = "0.2.0" version = "0.3.0"
description = "Marlin quantization kernels" description = "Marlin quantization kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:9a5afcf19b0f5917e43353cc19873fb3c4d4d0b924e2a95a37884f9ce208d0bd"}, {file = "marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:a2086b9e98d22071f52c5b4b4b98b1b4a988565258905173fa74c5a9eddd1a0a"},
] ]
[package.dependencies] [package.dependencies]
@ -1228,16 +1228,16 @@ torch = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
[[package]] [[package]]
name = "marlin-kernels" name = "marlin-kernels"
version = "0.2.0" version = "0.3.0"
description = "Marlin quantization kernels" description = "Marlin quantization kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:1e64fcc7ebadfaffa60091ee9201ae3daaf5c1be3be60c8c054143a3dcb72d5d"}, {file = "marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:f39a6946d8247629446ec170832d832c7038c363f1d8803211fe67249c2d804d"},
] ]
[package.dependencies] [package.dependencies]
@ -1245,16 +1245,16 @@ torch = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
[[package]] [[package]]
name = "marlin-kernels" name = "marlin-kernels"
version = "0.2.0" version = "0.3.0"
description = "Marlin quantization kernels" description = "Marlin quantization kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:e75f3ce9b1c13a4ed43a380d88e1d34d297259452db037ec1973ec33dc2eb78e"}, {file = "marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:07fd869d5289777fa866107dae676523e18b1f6ba4afce79946ddc58a6870169"},
] ]
[package.dependencies] [package.dependencies]
@ -1262,16 +1262,16 @@ torch = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
[[package]] [[package]]
name = "marlin-kernels" name = "marlin-kernels"
version = "0.2.0" version = "0.3.0"
description = "Marlin quantization kernels" description = "Marlin quantization kernels"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:2f99a27f70b391887ee6adffeeee7c3f4df7fac37393f9fb16d4cace2b3f6457"}, {file = "marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:0dedaa418225d490a5f1d8f85dbc75e439a8c43a8870e4ef32945bf61672d7dc"},
] ]
[package.dependencies] [package.dependencies]
@ -1279,7 +1279,7 @@ torch = "*"
[package.source] [package.source]
type = "url" type = "url"
url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
[[package]] [[package]]
name = "mdurl" name = "mdurl"


@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
numpy = "^1.26" numpy = "^1.26"
marlin-kernels = [ marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
] ]
moe-kernels = [ moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },


@ -28,10 +28,11 @@ else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
# KVCache needs `reshape_and_cache`, so ensure that it is defined already. # KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache from .kv_cache import KVCache, get_kv_scales
__all__ = [ __all__ = [
"attention", "attention",
"get_kv_scales",
"paged_attention", "paged_attention",
"SUPPORTS_WINDOWING", "SUPPORTS_WINDOWING",
"KVCache", "KVCache",


@ -1,5 +1,5 @@
import torch import torch
from text_generation_server.layers.attention.kv_cache import KVCache from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ( from text_generation_server.models.globals import (
ATTENTION, ATTENTION,
@ -8,6 +8,7 @@ from text_generation_server.models.globals import (
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from typing import Optional from typing import Optional
major, minor = torch.cuda.get_device_capability() major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5 is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
@ -21,6 +22,8 @@ def paged_attention(
block_tables: torch.Tensor, block_tables: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None, softcap: Optional[float] = None,
): ):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -46,6 +49,8 @@ def paged_attention(
num_seqs, num_heads, head_size = query.shape num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
can_scale = kv_cache.can_scale(kv_scales)
# NOTE(woosuk): We use a simple heuristic to decide whether to use # NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use # PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
@ -60,6 +65,8 @@ def paged_attention(
paged_kv_cache=(kv_cache.key, kv_cache.value), paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
) )
elif ATTENTION == "flashdecoding": elif ATTENTION == "flashdecoding":
max_q = 1 max_q = 1
@ -205,6 +212,7 @@ def attention(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: KVCache, kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen, seqlen: Seqlen,
block_tables: torch.Tensor, block_tables: torch.Tensor,
softmax_scale: float, softmax_scale: float,
@ -212,6 +220,8 @@ def attention(
causal: bool = True, causal: bool = True,
softcap: Optional[float] = None, softcap: Optional[float] = None,
): ):
can_scale = kv_cache.can_scale(kv_scales)
if ATTENTION == "flashinfer": if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import ( from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state, prefill_with_paged_kv_state,
@ -228,6 +238,8 @@ def attention(
logits_soft_cap=softcap, logits_soft_cap=softcap,
sm_scale=softmax_scale, sm_scale=softmax_scale,
window_left=window_size_left, window_left=window_size_left,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
) )
# If we are using flashdecoding or paged, we always use flash-attn for # If we are using flashdecoding or paged, we always use flash-attn for


@ -204,6 +204,7 @@ def use_decode_state(
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
page_size: int, page_size: int,
kv_cache_dtype: torch.dtype,
dtype: torch.dtype, dtype: torch.dtype,
window_left: int, window_left: int,
): ):
@ -240,7 +241,7 @@ def use_decode_state(
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
head_dim=head_size, head_dim=head_size,
page_size=page_size, page_size=page_size,
data_type=dtype, data_type=kv_cache_dtype,
q_data_type=dtype, q_data_type=dtype,
window_left=window_left, window_left=window_left,
) )


@ -1,6 +1,6 @@
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
import torch import torch
from text_generation_server.layers.attention.kv_cache import KVCache from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from typing import Optional from typing import Optional
@ -14,6 +14,7 @@ def attention(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: KVCache, kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen, seqlen: Seqlen,
block_tables: torch.Tensor, block_tables: torch.Tensor,
softmax_scale: float, softmax_scale: float,
@ -55,6 +56,8 @@ def paged_attention(
block_tables: torch.Tensor, block_tables: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None, softcap: Optional[float] = None,
): ):
if softcap is not None: if softcap is not None:


@ -1,8 +1,38 @@
from typing import Tuple from typing import Tuple
from dataclasses import dataclass, field
from loguru import logger
import torch import torch
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights
@dataclass
class KVScales:
"""
Key-value scales for FP8 KV cache.
This data class stores key and value scales both as a GPU tensor and
as a GPU float. This inconvenience is necessary because some functions
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
(e.g. flashinfer) take scales as a CPU scalar.
"""
key_scale: torch.Tensor
value_scale: torch.Tensor
key_scale_cpu: float = field(init=False)
value_scale_cpu: float = field(init=False)
def __post_init__(self):
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
raise ValueError("Key and value scales must be scalar tensors.")
self.key_scale_cpu = self.key_scale.item()
self.value_scale_cpu = self.value_scale.item()
class KVCache: class KVCache:
@ -76,6 +106,33 @@ class KVCache:
), ),
) )
def can_scale(self, kv_scales: KVScales) -> bool:
"""Check if the cache can be scaled by the given scales."""
if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
return False
elif (
self.dtype == torch.float8_e4m3fn
and ATTENTION == "flashinfer"
and SYSTEM == "cuda"
):
log_once(
logger.info,
"Using FP8 KV cache scales",
)
return True
else:
# We have scales, but not the correct FP8 cache type, so warn once.
log_once(
logger.info,
"Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
)
return False
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache[0].dtype
@property @property
def key(self): def key(self):
"""Get the key cache.""" """Get the key cache."""
@ -94,17 +151,33 @@ class KVCache:
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
kv_scales: KVScales,
): ):
"""Store the key and value at the given slots.""" """Store the key and value at the given slots."""
key_cache = self.kv_cache[0] key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1] value_cache = self.kv_cache[1]
if self.can_scale(kv_scales):
if kv_scales.key_scale_cpu != 1.0:
key = fp8_quantize(
key.float(),
scale=kv_scales.key_scale,
qdtype=self.dtype,
scalar=True,
)[0]
if kv_scales.value_scale_cpu != 1.0:
value = fp8_quantize(
value.float(),
scale=kv_scales.value_scale,
qdtype=self.dtype,
scalar=True,
)[0]
if ATTENTION in {"flashdecoding", "flashinfer"}: if ATTENTION in {"flashdecoding", "flashinfer"}:
# TODO: add scale
key = key.to(key_cache.dtype) key = key.to(key_cache.dtype)
value = value.to(value_cache.dtype) value = value.to(value_cache.dtype)
if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}: if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
# Torch index_put does not support float8_{e5m2,e4m3fn} yet, so # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
# put as raw data instead. # put as raw data instead.
key_cache = key_cache.view(torch.uint8) key_cache = key_cache.view(torch.uint8)
@ -151,5 +224,23 @@ def paged_reshape_and_cache(
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supportedattention" f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
) )
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
"""Load KV cache scales."""
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
value_scale = key_scale
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
f"{prefix}.v_scale"
):
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
elif weights.has_tensor(f"{prefix}.kv_scale"):
# Fall back to older more coarse-grained scale when available.
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
value_scale = key_scale
return KVScales(key_scale=key_scale, value_scale=value_scale)


@ -1,7 +1,7 @@
import os import os
from typing import Optional from typing import Optional
import torch import torch
from text_generation_server.layers.attention.kv_cache import KVCache from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
@ -36,6 +36,8 @@ def paged_attention(
block_tables: torch.Tensor, block_tables: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None, softcap: Optional[float] = None,
): ):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -210,6 +212,7 @@ def attention(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: KVCache, kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen, seqlen: Seqlen,
block_tables: torch.Tensor, block_tables: torch.Tensor,
softmax_scale: float, softmax_scale: float,


@ -26,6 +26,12 @@ def is_fbgemm_gpu_available():
return False return False
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
if is_fbgemm_gpu_available(): if is_fbgemm_gpu_available():
if SYSTEM == "cuda": if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability() major, _ = torch.cuda.get_device_capability()
@ -94,6 +100,17 @@ def fp8_quantize(
) )
return qweight, scale return qweight, scale
if marlin_kernels is not None:
shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
dtype=qdtype,
scale=scale,
scale_ub=scale_upper_bound,
)
return qweight.reshape(shape), scale
# weight, scale = quant_weights(weight, torch.int8, False) # weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype) finfo = torch.finfo(qdtype)


@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
attention, attention,
Seqlen, Seqlen,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -227,6 +228,7 @@ class FlashCohereAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.use_qk_norm = config.use_qk_norm self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm: if self.use_qk_norm:
@ -289,7 +291,12 @@ class FlashCohereAttention(torch.nn.Module):
self.rotary_emb(query, key, cos, sin) self.rotary_emb(query, key, cos, sin)
kv_cache.store(key=key, value=value, slots=slots) kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -299,6 +306,7 @@ class FlashCohereAttention(torch.nn.Module):
key=key, key=key,
value=value, value=value,
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -313,6 +321,7 @@ class FlashCohereAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj( return self.o_proj(


@ -20,6 +20,7 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "ipex": if SYSTEM != "ipex":
@ -288,6 +289,7 @@ class DbrxAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
@ -328,7 +330,12 @@ class DbrxAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -338,6 +345,7 @@ class DbrxAttention(torch.nn.Module):
key=kv[:, 0], key=kv[:, 0],
value=kv[:, 1], value=kv[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -352,6 +360,7 @@ class DbrxAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -34,6 +34,7 @@ from text_generation_server.layers.attention import (
attention, attention,
paged_attention, paged_attention,
) )
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
@ -230,6 +231,8 @@ class DeepseekV2Attention(torch.nn.Module):
), ),
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load( self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
) )
@ -258,7 +261,7 @@ class DeepseekV2Attention(torch.nn.Module):
cos: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor, cu_seqlen_prefill: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor], kv_cache: KVCache,
block_tables: torch.Tensor, block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
@ -319,7 +322,12 @@ class DeepseekV2Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0 value, (0, self.head_pad_size - self.value_head_size), value=0
) )
kv_cache.store(key=key, value=value, slots=slots) kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -329,6 +337,7 @@ class DeepseekV2Attention(torch.nn.Module):
key=key, key=key,
value=value, value=value,
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -343,6 +352,7 @@ class DeepseekV2Attention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
# Remove padding. # Remove padding.


@ -39,6 +39,7 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear, TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear, TensorParallelAdapterRowLinear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
@ -206,6 +207,7 @@ class FlashGemma2Attention(torch.nn.Module):
], ],
process_group=weights.process_group, process_group=weights.process_group,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
@ -251,7 +253,12 @@ class FlashGemma2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -261,6 +268,7 @@ class FlashGemma2Attention(torch.nn.Module):
key=kv[:, 0], key=kv[:, 0],
value=kv[:, 1], value=kv[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -278,6 +286,7 @@ class FlashGemma2Attention(torch.nn.Module):
seqlen, seqlen,
max_s, max_s,
softcap=self.softcap, softcap=self.softcap,
kv_scales=self.kv_scales,
) )
return self.o_proj( return self.o_proj(


@ -37,6 +37,7 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
@ -185,6 +186,7 @@ class FlashGemmaAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
@ -222,7 +224,12 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -232,6 +239,7 @@ class FlashGemmaAttention(torch.nn.Module):
key=kv[:, 0], key=kv[:, 0],
value=kv[:, 1], value=kv[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -247,6 +255,7 @@ class FlashGemmaAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -36,6 +36,7 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
def load_qkv(config, prefix: str, weights, head_size, num_heads): def load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -193,6 +194,7 @@ class FlashGPT2Attention(torch.nn.Module):
head_size=self.head_size, head_size=self.head_size,
num_heads=self.num_heads, num_heads=self.num_heads,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row( self.o_proj = load_row(
config, config,
@ -222,7 +224,12 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size)
kv_cache.store(key=key, value=value, slots=slots) kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -232,6 +239,7 @@ class FlashGPT2Attention(torch.nn.Module):
key=key, key=key,
value=value, value=value,
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -246,6 +254,7 @@ class FlashGPT2Attention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -24,6 +24,7 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
@ -138,6 +139,7 @@ class FlashGPTJAttention(torch.nn.Module):
prefix=prefix, prefix=prefix,
weights=weights, weights=weights,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row( self.o_proj = load_row(
config, config,
@ -184,7 +186,12 @@ class FlashGPTJAttention(torch.nn.Module):
else: else:
self.rotary_emb(query, key, cos, sin) self.rotary_emb(query, key, cos, sin)
kv_cache.store(key=key, value=value, slots=slots) kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -194,6 +201,7 @@ class FlashGPTJAttention(torch.nn.Module):
key=key, key=key,
value=value, value=value,
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -208,6 +216,7 @@ class FlashGPTJAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@ -27,7 +27,10 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from text_generation_server.layers.attention import KVCache from text_generation_server.layers.attention import (
KVCache,
get_kv_scales,
)
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
@ -179,6 +182,8 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights, index) self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index self.index = index
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
@ -224,7 +229,12 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -233,6 +243,7 @@ class FlashLlamaAttention(torch.nn.Module):
query=query, query=query,
key=kv[:, 0], key=kv[:, 0],
value=kv[:, 1], value=kv[:, 1],
kv_scales=self.kv_scales,
kv_cache=kv_cache, kv_cache=kv_cache,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
@ -248,6 +259,7 @@ class FlashLlamaAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj( return self.o_proj(


@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
@ -158,6 +159,7 @@ class MistralAttention(torch.nn.Module):
], ],
process_group=weights.process_group, process_group=weights.process_group,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
@ -208,7 +210,12 @@ class MistralAttention(torch.nn.Module):
else: else:
kv_to_cache = kv kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -218,6 +225,7 @@ class MistralAttention(torch.nn.Module):
key=kv_to_cache[:, 0], key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1], value=kv_to_cache[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -233,6 +241,7 @@ class MistralAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj( return self.o_proj(


@ -38,6 +38,7 @@ from text_generation_server.layers.attention import (
attention, attention,
paged_attention, paged_attention,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
@ -213,6 +214,7 @@ class MixtralAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
@ -256,7 +258,12 @@ class MixtralAttention(torch.nn.Module):
else: else:
kv_to_cache = kv kv_to_cache = kv
kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) kv_cache.store(
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
@ -266,6 +273,7 @@ class MixtralAttention(torch.nn.Module):
key=kv_to_cache[:, 0], key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1], value=kv_to_cache[:, 1],
kv_cache=kv_cache, kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen, seqlen=seqlen,
block_tables=block_tables, block_tables=block_tables,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
@ -281,6 +289,7 @@ class MixtralAttention(torch.nn.Module):
block_tables, block_tables,
seqlen, seqlen,
max_s, max_s,
kv_scales=self.kv_scales,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@@ -38,6 +38,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -130,6 +131,7 @@ class FlashNeoxAttention(torch.nn.Module):
             head_size=self.head_size,
             hidden_size=self.hidden_size,
         )
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=True
         )
@@ -163,7 +165,12 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
         qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)

-        kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots)
+        kv_cache.store(
+            key=qkv[:, 1],
+            value=qkv[:, 2],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -173,6 +180,7 @@ class FlashNeoxAttention(torch.nn.Module):
                 key=qkv[:, 1],
                 value=qkv[:, 2],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -187,6 +195,7 @@ class FlashNeoxAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
View File
@@ -18,6 +18,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -137,6 +138,7 @@ class FlashPhiAttention(torch.nn.Module):
         )
         self.query_key_value = load_attention(config, prefix, weights)
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")

         # in llama the dense layer is called "o_proj" and has bias=False
         self.dense = TensorParallelRowLinear.load(
@@ -186,7 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
         )

         # Reshape key and value and cache
-        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
+        kv_cache.store(
+            key=kv[:, 0],
+            value=kv[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -194,6 +201,7 @@ class FlashPhiAttention(torch.nn.Module):
                 query=query,
                 key=kv[:, 0],
                 value=kv[:, 1],
+                kv_scales=self.kv_scales,
                 kv_cache=kv_cache,
                 seqlen=seqlen,
                 block_tables=block_tables,
@@ -209,6 +217,7 @@ class FlashPhiAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
View File
@@ -16,6 +16,7 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     SpeculativeHead,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -84,6 +85,8 @@ class Qwen2Attention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights)
+
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")

         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
@@ -126,7 +129,12 @@ class Qwen2Attention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
+        kv_cache.store(
+            key=kv_to_cache[:, 0],
+            value=kv_to_cache[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -136,6 +144,7 @@ class Qwen2Attention(torch.nn.Module):
                 key=kv_to_cache[:, 0],
                 value=kv_to_cache[:, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -151,6 +160,7 @@ class Qwen2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
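Whether the loaded scales are actually applied is a separate decision from whether they exist in the checkpoint: they are only useful with a `float8_e4m3fn` cache and an attention backend that can dequantize with them (flashinfer in this change). A sketch of such a guard, with assumed names:

import torch


def can_scale_sketch(kv_cache_dtype: torch.dtype, attention_backend: str) -> bool:
    # Scaling only makes sense when the cache is FP8...
    if kv_cache_dtype != torch.float8_e4m3fn:
        return False
    # ...and the attention implementation knows how to consume the scales.
    return attention_backend == "flashinfer"

Scales of exactly 1.0 are a no-op either way, so the fallback path stays numerically identical to the previous behaviour.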
View File
@@ -12,6 +12,7 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     get_linear,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.layernorm import FastLayerNorm
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.attention import (
@@ -158,6 +159,7 @@ class FlashRWAttention(torch.nn.Module):
             weights=weights,
             bias=config.bias,
         )
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )
@@ -198,7 +200,12 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots)
+        kv_cache.store(
+            key=kv[:, 0],
+            value=kv[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -208,6 +215,7 @@ class FlashRWAttention(torch.nn.Module):
                 key=kv[:, 0],
                 value=kv[:, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -222,6 +230,7 @@ class FlashRWAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -276,6 +285,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             weights=weights,
             bias=config.bias,
         )
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
         )
@@ -311,7 +321,10 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)

         kv_cache.store(
-            key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots
+            key=kv[:, :, 0].contiguous(),
+            value=kv[:, :, 1].contiguous(),
+            slots=slots,
+            kv_scales=self.kv_scales,
         )

         # Prefill
@@ -322,6 +335,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 key=kv[:, :, 0],
                 value=kv[:, :, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -336,6 +350,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )

         return self.dense(
View File
@@ -17,6 +17,7 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     get_linear,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
@@ -257,6 +258,7 @@ class FlashMQAttention(torch.nn.Module):
         self.c_proj = load_row(
             config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
         )
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")
         self.kv_head_mapping = torch.zeros(
             self.num_heads, dtype=torch.int32, device=weights.device
         )
@@ -282,7 +284,12 @@ class FlashMQAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key_value = key_value.view(-1, 2, 1, self.head_size)

-        kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots)
+        kv_cache.store(
+            key=key_value[:, 0],
+            value=key_value[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -292,6 +299,7 @@ class FlashMQAttention(torch.nn.Module):
                 key=key_value[:, 0],
                 value=key_value[:, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -306,6 +314,7 @@ class FlashMQAttention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )

         return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
View File
@@ -38,6 +38,7 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
     FastRMSNorm,
@@ -188,6 +189,7 @@ class Starcoder2Attention(torch.nn.Module):
         )
         self.query_key_value = load_attention(config, prefix, weights)
+        self.kv_scales = get_kv_scales(weights, f"{prefix}")

         self.o_proj = TensorParallelRowLinear.load(
             config,
@@ -231,7 +233,12 @@ class Starcoder2Attention(torch.nn.Module):
         else:
             kv_to_cache = kv

-        kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots)
+        kv_cache.store(
+            key=kv_to_cache[:, 0],
+            value=kv_to_cache[:, 1],
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
@@ -241,6 +248,7 @@ class Starcoder2Attention(torch.nn.Module):
                 key=kv_to_cache[:, 0],
                 value=kv_to_cache[:, 1],
                 kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
                 seqlen=seqlen,
                 block_tables=block_tables,
                 softmax_scale=self.softmax_scale,
@@ -256,6 +264,7 @@ class Starcoder2Attention(torch.nn.Module):
                 block_tables,
                 seqlen,
                 max_s,
+                kv_scales=self.kv_scales,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
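All of the per-model changes above expect the scales to live next to the regular attention weights in the checkpoint, under the same prefix each module is loaded from. A toy example of producing such a file with safetensors; the `model.layers.0.self_attn` prefix and the numeric values are assumptions for illustration, not a requirement of the loader:

import torch
from safetensors.torch import save_file

# Toy checkpoint fragment: FP32 scale scalars stored under the attention prefix.
scales = {
    "model.layers.0.self_attn.k_scale": torch.tensor(0.023, dtype=torch.float32),
    "model.layers.0.self_attn.v_scale": torch.tensor(0.031, dtype=torch.float32),
    # An older-style checkpoint would instead carry a single shared scalar:
    # "model.layers.0.self_attn.kv_scale": torch.tensor(0.027, dtype=torch.float32),
}
save_file(scales, "kv-scales-demo.safetensors")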
View File
@@ -2283,6 +2283,7 @@ class FlashCausalLM(Model):
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
+                kv_cache_dtype=self.kv_cache_dtype,
                 dtype=self.dtype,
                 window_left=self.sliding_window,
             )
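`kv_cache_dtype` is passed alongside `dtype` here because the two can legitimately differ: the model can run in float16/bfloat16 while the cache is held in FP8, and the attention metadata has to be set up for the cache's dtype rather than the model's. A small illustration of the two dtypes diverging (shapes and names are assumptions, not the actual TGI cache layout):

import torch

num_blocks, block_size, num_kv_heads, head_size = 1024, 16, 8, 128

model_dtype = torch.bfloat16          # dtype of weights and activations
kv_cache_dtype = torch.float8_e4m3fn  # dtype of the paged KV cache (when enabled)

# In practice these live on the GPU; allocation shown without a device for brevity.
key_cache = torch.empty(
    (num_blocks, block_size, num_kv_heads, head_size), dtype=kv_cache_dtype
)
value_cache = torch.empty_like(key_cache)

assert key_cache.dtype != model_dtype  # both dtypes must be communicated downstream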
View File
@@ -207,7 +207,9 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()

-    def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
+    def get_tensor(
+        self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
+    ) -> torch.Tensor:
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
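The annotated `get_tensor` signature highlights the `to_dtype` switch, which matters for the scale scalars: casting them to the serving dtype (for example bfloat16) before dividing could lose some of the precision the calibration produced. A hypothetical call site, assuming `to_dtype=False` leaves the stored dtype untouched:

import torch


def load_k_scale_sketch(weights, prefix: str) -> torch.Tensor:
    # Hypothetical: keep the stored dtype (to_dtype=False), then promote to
    # float32 so the FP8 scaling math is not done in the model dtype.
    return weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()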