From 2358c2bb54cf60a1596267192451a46a24f03e06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 4 Oct 2024 17:51:48 +0200 Subject: [PATCH] Add basic FP8 KV cache support (#2603) * Add basic FP8 KV cache support This change adds rudimentary FP8 KV cache support. The support is enabled by passing `--kv-cache-dtype fp8_e5m2` to the launcher. Doing so uses this type for the KV cache. However support is still limited: * Only the `fp8_e5m2` type is supported. * The KV cache layout is the same as `float16`/`bfloat16` (HND). * The FP8 KV cache is only supported for FlashInfer. * Loading of scales is not yet supported. * Fix Cargo.toml --- Cargo.lock | 14 +- docs/source/reference/launcher.md | 9 + integration-tests/conftest.py | 9 +- .../test_flash_llama_fp8_kv_cache.json | 104 +++++ ...t_flash_llama_fp8_kv_cache_all_params.json | 57 +++ .../test_flash_llama_fp8_kv_cache_load.json | 418 ++++++++++++++++++ .../models/test_flash_llama_fp8_kv_cache.py | 77 ++++ launcher/src/main.rs | 30 ++ server/text_generation_server/cli.py | 7 + .../layers/attention/__init__.py | 18 +- .../layers/attention/cuda.py | 8 + .../layers/attention/ipex.py | 9 + .../layers/attention/kv_cache.py | 121 +++++ .../layers/attention/rocm.py | 8 + .../text_generation_server/models/__init__.py | 31 ++ .../custom_modeling/flash_cohere_modeling.py | 11 +- .../custom_modeling/flash_dbrx_modeling.py | 11 +- .../flash_deepseek_v2_modeling.py | 11 +- .../custom_modeling/flash_gemma2_modeling.py | 11 +- .../custom_modeling/flash_gemma_modeling.py | 11 +- .../custom_modeling/flash_gpt2_modeling.py | 11 +- .../custom_modeling/flash_gptj_modeling.py | 11 +- .../custom_modeling/flash_llama_modeling.py | 15 +- .../custom_modeling/flash_mistral_modeling.py | 13 +- .../custom_modeling/flash_mixtral_modeling.py | 13 +- .../custom_modeling/flash_neox_modeling.py | 11 +- .../custom_modeling/flash_phi_modeling.py | 11 +- .../custom_modeling/flash_qwen2_modeling.py | 13 +- .../custom_modeling/flash_rw_modeling.py | 27 +- .../flash_santacoder_modeling.py | 13 +- .../flash_starcoder2_modeling.py | 13 +- .../models/flash_causal_lm.py | 77 +--- server/text_generation_server/server.py | 4 + 33 files changed, 1015 insertions(+), 192 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json create mode 100644 integration-tests/models/test_flash_llama_fp8_kv_cache.py create mode 100644 server/text_generation_server/layers/attention/kv_cache.py diff --git a/Cargo.lock b/Cargo.lock index 27499cd4..5e85e384 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4177,7 +4177,7 @@ dependencies = [ [[package]] name = "text-generation-backends-trtllm" -version = "2.3.1-dev0" +version = "2.3.2-dev0" dependencies = [ "async-stream", "async-trait", @@ -4200,7 +4200,7 @@ dependencies = [ [[package]] name = "text-generation-benchmark" -version = "2.3.1-dev0" +version = "2.3.2-dev0" dependencies = [ "average", "clap 4.5.18", @@ -4220,7 +4220,7 @@ dependencies = [ [[package]] name = "text-generation-client" -version = "2.3.1-dev0" +version = "2.3.2-dev0" dependencies = [ "async-trait", "base64 0.22.1", @@ -4238,7 +4238,7 @@ dependencies = [ [[package]] name = "text-generation-launcher" -version = "2.3.1-dev0" +version = "2.3.2-dev0" dependencies = [ "clap 4.5.18", "ctrlc", @@ -4259,7 +4259,7 @@ dependencies = [ [[package]] name = "text-generation-router" -version = "2.3.1-dev0" +version = "2.3.2-dev0" dependencies = [ "async-stream", "async-trait", @@ -4308,7 +4308,7 @@ dependencies = [ [[package]] name = "text-generation-router-v2" -version = "2.3.1-dev0" +version = "2.3.2-dev0" dependencies = [ "async-stream", "async-trait", @@ -4357,7 +4357,7 @@ dependencies = [ [[package]] name = "text-generation-router-v3" -version = "2.3.1-dev0" +version = "2.3.2-dev0" dependencies = [ "async-stream", "async-trait", diff --git a/docs/source/reference/launcher.md b/docs/source/reference/launcher.md index c8d2a4c6..b1abd1ee 100644 --- a/docs/source/reference/launcher.md +++ b/docs/source/reference/launcher.md @@ -89,6 +89,15 @@ Options: [env: DTYPE=] [possible values: float16, bfloat16] +``` +## KV_CACHE_DTYPE +```shell + --kv-cache-dtype + Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA + + [env: KV_CACHE_DTYPE=] + [possible values: fp8_e5m2] + ``` ## TRUST_REMOTE_CODE ```shell diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index eb55ebb9..4c8c929f 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -336,6 +336,7 @@ def launcher(event_loop): use_flash_attention: bool = True, disable_grammar_support: bool = False, dtype: Optional[str] = None, + kv_cache_dtype: Optional[str] = None, revision: Optional[str] = None, max_input_length: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None, @@ -375,6 +376,9 @@ def launcher(event_loop): if dtype is not None: args.append("--dtype") args.append(dtype) + if kv_cache_dtype is not None: + args.append("--kv-cache-dtype") + args.append(kv_cache_dtype) if revision is not None: args.append("--revision") args.append(revision) @@ -434,6 +438,7 @@ def launcher(event_loop): use_flash_attention: bool = True, disable_grammar_support: bool = False, dtype: Optional[str] = None, + kv_cache_dtype: Optional[str] = None, revision: Optional[str] = None, max_input_length: Optional[int] = None, max_batch_prefill_tokens: Optional[int] = None, @@ -456,6 +461,9 @@ def launcher(event_loop): if dtype is not None: args.append("--dtype") args.append(dtype) + if kv_cache_dtype is not None: + args.append("--kv-cache-dtype") + args.append(kv_cache_dtype) if revision is not None: args.append("--revision") args.append(revision) @@ -589,7 +597,6 @@ def generate_multi(): max_new_tokens: int, seed: Optional[int] = None, ) -> List[Response]: - import numpy as np arange = np.arange(len(prompts)) diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json new file mode 100644 index 00000000..c55dd593 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json @@ -0,0 +1,104 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 3923, + "logprob": -5.6328125, + "text": "What" + }, + { + "id": 374, + "logprob": -1.2265625, + "text": " is" + }, + { + "id": 5655, + "logprob": -9.1015625, + "text": " deep" + }, + { + "id": 6975, + "logprob": -1.8085938, + "text": " learning" + }, + { + "id": 30, + "logprob": -1.0439453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 18682, + "logprob": -2.1992188, + "special": false, + "text": " Deep" + }, + { + "id": 6975, + "logprob": -0.079956055, + "special": false, + "text": " learning" + }, + { + "id": 374, + "logprob": -0.2763672, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.37548828, + "special": false, + "text": " a" + }, + { + "id": 27084, + "logprob": -1.4628906, + "special": false, + "text": " subset" + }, + { + "id": 315, + "logprob": -0.02885437, + "special": false, + "text": " of" + }, + { + "id": 5780, + "logprob": -0.2565918, + "special": false, + "text": " machine" + }, + { + "id": 6975, + "logprob": -0.0063438416, + "special": false, + "text": " learning" + }, + { + "id": 430, + "logprob": -1.3056641, + "special": false, + "text": " that" + }, + { + "id": 374, + "logprob": -1.6035156, + "special": false, + "text": " is" + } + ], + "top_tokens": null + }, + "generated_text": " Deep learning is a subset of machine learning that is" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json new file mode 100644 index 00000000..d06d6e56 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json @@ -0,0 +1,57 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "eos_token", + "generated_tokens": 3, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 374, + "logprob": -22.96875, + "text": " is" + }, + { + "id": 5655, + "logprob": -10.71875, + "text": " deep" + }, + { + "id": 6975, + "logprob": -2.6992188, + "text": " learning" + }, + { + "id": 30, + "logprob": -4.8398438, + "text": "?" + } + ], + "seed": 0, + "tokens": [ + { + "id": 720, + "logprob": -0.4411621, + "special": false, + "text": " \n" + }, + { + "id": 220, + "logprob": -0.35864258, + "special": false, + "text": " " + }, + { + "id": 128001, + "logprob": 0.0, + "special": true, + "text": "<|end_of_text|>" + } + ], + "top_tokens": null + }, + "generated_text": "What is deep learning? \n " +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json new file mode 100644 index 00000000..46670819 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json @@ -0,0 +1,418 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 3923, + "logprob": -5.6328125, + "text": "What" + }, + { + "id": 374, + "logprob": -1.2265625, + "text": " is" + }, + { + "id": 5655, + "logprob": -9.1015625, + "text": " deep" + }, + { + "id": 6975, + "logprob": -1.8085938, + "text": " learning" + }, + { + "id": 30, + "logprob": -1.0439453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 18682, + "logprob": -2.1992188, + "special": false, + "text": " Deep" + }, + { + "id": 6975, + "logprob": -0.07897949, + "special": false, + "text": " learning" + }, + { + "id": 374, + "logprob": -0.27734375, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.37402344, + "special": false, + "text": " a" + }, + { + "id": 27084, + "logprob": -1.4511719, + "special": false, + "text": " subset" + }, + { + "id": 315, + "logprob": -0.02909851, + "special": false, + "text": " of" + }, + { + "id": 5780, + "logprob": -0.25854492, + "special": false, + "text": " machine" + }, + { + "id": 6975, + "logprob": -0.0061798096, + "special": false, + "text": " learning" + }, + { + "id": 430, + "logprob": -1.3046875, + "special": false, + "text": " that" + }, + { + "id": 374, + "logprob": -1.5537109, + "special": false, + "text": " is" + } + ], + "top_tokens": null + }, + "generated_text": " Deep learning is a subset of machine learning that is" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 3923, + "logprob": -5.6328125, + "text": "What" + }, + { + "id": 374, + "logprob": -1.2265625, + "text": " is" + }, + { + "id": 5655, + "logprob": -9.1015625, + "text": " deep" + }, + { + "id": 6975, + "logprob": -1.8085938, + "text": " learning" + }, + { + "id": 30, + "logprob": -1.0439453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 18682, + "logprob": -2.1992188, + "special": false, + "text": " Deep" + }, + { + "id": 6975, + "logprob": -0.07897949, + "special": false, + "text": " learning" + }, + { + "id": 374, + "logprob": -0.27734375, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.37402344, + "special": false, + "text": " a" + }, + { + "id": 27084, + "logprob": -1.4511719, + "special": false, + "text": " subset" + }, + { + "id": 315, + "logprob": -0.02909851, + "special": false, + "text": " of" + }, + { + "id": 5780, + "logprob": -0.25854492, + "special": false, + "text": " machine" + }, + { + "id": 6975, + "logprob": -0.0061798096, + "special": false, + "text": " learning" + }, + { + "id": 430, + "logprob": -1.3046875, + "special": false, + "text": " that" + }, + { + "id": 374, + "logprob": -1.5537109, + "special": false, + "text": " is" + } + ], + "top_tokens": null + }, + "generated_text": " Deep learning is a subset of machine learning that is" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 3923, + "logprob": -5.6328125, + "text": "What" + }, + { + "id": 374, + "logprob": -1.2265625, + "text": " is" + }, + { + "id": 5655, + "logprob": -9.1015625, + "text": " deep" + }, + { + "id": 6975, + "logprob": -1.8085938, + "text": " learning" + }, + { + "id": 30, + "logprob": -1.0439453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 18682, + "logprob": -2.1992188, + "special": false, + "text": " Deep" + }, + { + "id": 6975, + "logprob": -0.07897949, + "special": false, + "text": " learning" + }, + { + "id": 374, + "logprob": -0.27734375, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.37402344, + "special": false, + "text": " a" + }, + { + "id": 27084, + "logprob": -1.4511719, + "special": false, + "text": " subset" + }, + { + "id": 315, + "logprob": -0.02909851, + "special": false, + "text": " of" + }, + { + "id": 5780, + "logprob": -0.25854492, + "special": false, + "text": " machine" + }, + { + "id": 6975, + "logprob": -0.0061798096, + "special": false, + "text": " learning" + }, + { + "id": 430, + "logprob": -1.3046875, + "special": false, + "text": " that" + }, + { + "id": 374, + "logprob": -1.5537109, + "special": false, + "text": " is" + } + ], + "top_tokens": null + }, + "generated_text": " Deep learning is a subset of machine learning that is" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 3923, + "logprob": -5.6328125, + "text": "What" + }, + { + "id": 374, + "logprob": -1.2265625, + "text": " is" + }, + { + "id": 5655, + "logprob": -9.1015625, + "text": " deep" + }, + { + "id": 6975, + "logprob": -1.8085938, + "text": " learning" + }, + { + "id": 30, + "logprob": -1.0439453, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 18682, + "logprob": -2.1992188, + "special": false, + "text": " Deep" + }, + { + "id": 6975, + "logprob": -0.07897949, + "special": false, + "text": " learning" + }, + { + "id": 374, + "logprob": -0.27734375, + "special": false, + "text": " is" + }, + { + "id": 264, + "logprob": -0.37402344, + "special": false, + "text": " a" + }, + { + "id": 27084, + "logprob": -1.4511719, + "special": false, + "text": " subset" + }, + { + "id": 315, + "logprob": -0.02909851, + "special": false, + "text": " of" + }, + { + "id": 5780, + "logprob": -0.25854492, + "special": false, + "text": " machine" + }, + { + "id": 6975, + "logprob": -0.0061798096, + "special": false, + "text": " learning" + }, + { + "id": 430, + "logprob": -1.3046875, + "special": false, + "text": " that" + }, + { + "id": 374, + "logprob": -1.5537109, + "special": false, + "text": " is" + } + ], + "top_tokens": null + }, + "generated_text": " Deep learning is a subset of machine learning that is" + } +] diff --git a/integration-tests/models/test_flash_llama_fp8_kv_cache.py b/integration-tests/models/test_flash_llama_fp8_kv_cache.py new file mode 100644 index 00000000..05e9f0dd --- /dev/null +++ b/integration-tests/models/test_flash_llama_fp8_kv_cache.py @@ -0,0 +1,77 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_fp8_kv_cache_handle(launcher): + with launcher( + "meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2" + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache_handle): + await flash_llama_fp8_kv_cache_handle.health(300) + return flash_llama_fp8_kv_cache_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snapshot): + response = await flash_llama_fp8_kv_cache.generate( + "What is deep learning?", max_new_tokens=10, decoder_input_details=True + ) + + assert ( + response.generated_text + == " Deep learning is a subset of machine learning that is" + ) + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_kv_cache_all_params( + flash_llama_fp8_kv_cache, response_snapshot +): + response = await flash_llama_fp8_kv_cache.generate( + "What is deep learning?", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_kv_cache_load( + flash_llama_fp8_kv_cache, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_fp8_kv_cache, "What is deep learning?", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert ( + responses[0].generated_text + == " Deep learning is a subset of machine learning that is" + ) + assert all( + [r.generated_text == responses[0].generated_text for r in responses] + ), f"Different messages : {[r.generated_text for r in responses]}" + assert responses == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index aba497d6..214adcdc 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -301,6 +301,22 @@ impl std::fmt::Display for Dtype { } } +#[derive(Clone, Copy, Debug, ValueEnum)] +enum KVCacheDtype { + #[clap(name = "fp8_e5m2")] + Fp8e5m2, +} + +impl std::fmt::Display for KVCacheDtype { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + KVCacheDtype::Fp8e5m2 => { + write!(f, "fp8_e5m2") + } + } + } +} + #[derive(Clone, Copy, Debug, ValueEnum)] enum RopeScaling { Linear, @@ -402,6 +418,12 @@ struct Args { #[clap(long, env, value_enum)] dtype: Option, + /// Specify the dtype for the key-value cache. When this option is not provided, + /// the dtype of the model is used (typically `float16` or `bfloat16`). Currently + /// the only supported value is `fp8_e5m2` on CUDA. + #[clap(long, env, value_enum)] + kv_cache_dtype: Option, + /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is /// encouraged when loading a model with custom code to ensure no malicious code has been /// contributed in a newer revision. @@ -670,6 +692,7 @@ fn shard_manager( quantize: Option, speculate: Option, dtype: Option, + kv_cache_dtype: Option, trust_remote_code: bool, uds_path: String, rank: usize, @@ -743,6 +766,11 @@ fn shard_manager( shard_args.push(dtype.to_string()) } + if let Some(kv_cache_dtype) = kv_cache_dtype { + shard_args.push("--kv-cache-dtype".to_string()); + shard_args.push(kv_cache_dtype.to_string()) + } + // Model optional revision if let Some(revision) = revision { shard_args.push("--revision".to_string()); @@ -1299,6 +1327,7 @@ fn spawn_shards( let otlp_service_name = args.otlp_service_name.clone(); let speculate = args.speculate; let dtype = args.dtype; + let kv_cache_dtype = args.kv_cache_dtype; let trust_remote_code = args.trust_remote_code; let master_port = args.master_port; let disable_custom_kernels = args.disable_custom_kernels; @@ -1317,6 +1346,7 @@ fn spawn_shards( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, uds_path, rank, diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 10aa3a3b..db390234 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -30,6 +30,10 @@ class Dtype(str, Enum): bloat16 = "bfloat16" +class KVCacheDtype(str, Enum): + fp8_e5m2 = "fp8_e5m2" + + @app.command() def serve( model_id: str, @@ -38,6 +42,7 @@ def serve( quantize: Optional[Quantization] = None, speculate: Optional[int] = None, dtype: Optional[Dtype] = None, + kv_cache_dtype: Optional[KVCacheDtype] = None, trust_remote_code: bool = False, uds_path: Path = "/tmp/text-generation-server", logger_level: str = "INFO", @@ -97,6 +102,7 @@ def serve( # Downgrade enum into str for easier management later on quantize = None if quantize is None else quantize.value dtype = None if dtype is None else dtype.value + kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value if dtype is not None and quantize not in { None, "bitsandbytes", @@ -114,6 +120,7 @@ def serve( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, uds_path, max_input_tokens, diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index 4f2b9807..cc7f0caa 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -1,37 +1,40 @@ -from text_generation_server.utils.import_utils import SYSTEM import os +from text_generation_server.utils.import_utils import SYSTEM + from .common import Seqlen if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") if SYSTEM == "cuda": from .cuda import ( + PREFILL_IN_KV_CACHE, + SUPPORTS_WINDOWING, attention, paged_attention, reshape_and_cache, - SUPPORTS_WINDOWING, - PREFILL_IN_KV_CACHE, ) elif SYSTEM == "rocm": from .rocm import ( + PREFILL_IN_KV_CACHE, + SUPPORTS_WINDOWING, attention, paged_attention, reshape_and_cache, - PREFILL_IN_KV_CACHE, - SUPPORTS_WINDOWING, ) elif SYSTEM == "ipex": from .ipex import ( + PREFILL_IN_KV_CACHE, + SUPPORTS_WINDOWING, attention, paged_attention, reshape_and_cache, - PREFILL_IN_KV_CACHE, - SUPPORTS_WINDOWING, ) else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") +# KVCache needs `reshape_and_cache`, so ensure that it is defined already. +from .kv_cache import KVCache __all__ = [ "attention", @@ -39,5 +42,6 @@ __all__ = [ "reshape_and_cache", "PREFILL_IN_KV_CACHE", "SUPPORTS_WINDOWING", + "KVCache", "Seqlen", ] diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 51af928d..cd3ea369 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -355,3 +355,11 @@ else: # have a configuration that requires flash-attention v1, which # does not support block tables. PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2 + +__all__ = [ + "PREFILL_IN_KV_CACHE", + "SUPPORTS_WINDOWING", + "attention", + "paged_attention", + "reshape_and_cache", +] diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 657c90af..131c9bb0 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -80,3 +80,12 @@ def paged_attention( None, ) return out + + +__all__ = [ + "PREFILL_IN_KV_CACHE", + "SUPPORTS_WINDOWING", + "attention", + "paged_attention", + "reshape_and_cache", +] diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py new file mode 100644 index 00000000..80033122 --- /dev/null +++ b/server/text_generation_server/layers/attention/kv_cache.py @@ -0,0 +1,121 @@ +from typing import Tuple + +import torch +from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.layers.attention import reshape_and_cache + + +class KVCache: + """ + Key-value cache for attention layers. + """ + + kv_cache: Tuple[torch.Tensor, torch.Tensor] + + def __init__( + self, + *, + num_blocks: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + ): + """Construct the key-value cache for a layer.""" + + if ( + dtype == torch.float8_e5m2 + and ATTENTION != "flashinfer" + and SYSTEM != "cuda" + ): + raise ValueError( + "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA" + ) + + element_size = torch.tensor([], dtype=dtype).element_size() + if SYSTEM == "ipex" and device.type == "xpu": + x = 1 + else: + x = BLOCK_SIZE // element_size + + if ATTENTION in {"flashdecoding", "flashinfer"}: + self.kv_cache = ( + torch.empty( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + ) + elif SYSTEM == "ipex" and device == torch.device("cpu"): + self.kv_cache = ( + torch.empty( + (num_blocks, num_heads, BLOCK_SIZE, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, BLOCK_SIZE, head_size), + dtype=dtype, + device=device, + ), + ) + else: + self.kv_cache = ( + torch.zeros( + (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), + dtype=dtype, + device=device, + ), + torch.zeros( + (num_blocks, num_heads, head_size, BLOCK_SIZE), + dtype=dtype, + device=device, + ), + ) + + @property + def key(self): + """Get the key cache.""" + + return self.kv_cache[0] + + @property + def value(self): + """Get the value cache.""" + + return self.kv_cache[1] + + def store( + self, + *, + key: torch.Tensor, + value: torch.Tensor, + slots: torch.Tensor, + ): + """Store the key and value at the given slots.""" + + key_cache = self.kv_cache[0] + value_cache = self.kv_cache[1] + + if ATTENTION in {"flashdecoding", "flashinfer"}: + # TODO: add scale + key = key.to(key_cache.dtype) + value = value.to(value_cache.dtype) + if key_cache.dtype == torch.float8_e5m2: + # Torch index_put does not support float8_e5m2 yet, so + # put as raw data instead. + key_cache = key_cache.view(torch.uint8) + value_cache = value_cache.view(torch.uint8) + key = key.view(torch.uint8) + value = value.view(torch.uint8) + shape = key_cache.shape + key_cache.view(-1, shape[-2], shape[-1])[slots] = key + value_cache.view(-1, shape[-2], shape[-1])[slots] = value + else: + reshape_and_cache(key, value, key_cache, value_cache, slots) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 646a763d..01d4685a 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -306,3 +306,11 @@ elif ENGINE == "triton": else: raise RuntimeError(f"Unknown attention engine {ENGINE}") + +__all__ = [ + "PREFILL_IN_KV_CACHE", + "SUPPORTS_WINDOWING", + "attention", + "paged_attention", + "reshape_and_cache", +] diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 4dabf71d..17eed976 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -342,6 +342,7 @@ def get_model( quantize: Optional[str], speculate: Optional[int], dtype: Optional[str], + kv_cache_dtype: Optional[str], trust_remote_code: bool, max_input_tokens: int, ) -> Model: @@ -403,6 +404,13 @@ def get_model( else: raise RuntimeError(f"Unknown dtype {dtype}") + if kv_cache_dtype is None: + kv_cache_dtype = dtype + elif kv_cache_dtype == "fp8_e5m2": + kv_cache_dtype = torch.float8_e5m2 + else: + raise RuntimeError(f"Unknown kv_cache_dtype: {kv_cache_dtype}") + if speculate is not None: set_speculate(speculate) else: @@ -563,6 +571,7 @@ def get_model( speculator=speculator, default_dtype=torch.bfloat16, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, config_class=DeepseekV2Config, @@ -617,6 +626,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, aliases={"transformer.wte.weight": ["lm_head.weight"]}, @@ -668,6 +678,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -703,6 +714,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -741,6 +753,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, config_class=GPTNeoXConfig, @@ -774,6 +787,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -797,6 +811,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -836,6 +851,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -859,6 +875,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, # Works better for these models default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, @@ -884,6 +901,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, # Works better for these models default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, @@ -910,6 +928,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -934,6 +953,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, # Dbrx works better in bfloat16. default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, @@ -964,6 +984,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, aliases={ "lm_head.weight": ["transformer.word_embeddings.weight"], "transformer.word_embeddings.weight": ["lm_head.weight"], @@ -982,6 +1003,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, aliases={ "lm_head.weight": ["transformer.word_embeddings.weight"], "transformer.word_embeddings.weight": ["lm_head.weight"], @@ -1009,6 +1031,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -1033,6 +1056,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -1057,6 +1081,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -1083,6 +1108,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -1162,6 +1188,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, # XXX: Extremely important to cap resolution in order to limit @@ -1179,6 +1206,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, # Works better for these models default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, @@ -1197,6 +1225,7 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + kv_cache_dtype=kv_cache_dtype, trust_remote_code=trust_remote_code, ) else: @@ -1269,6 +1298,7 @@ def get_model_with_lora_adapters( quantize: Optional[str], speculate: Optional[int], dtype: Optional[str], + kv_cache_dtype: Optional[str], trust_remote_code: bool, max_input_tokens: int, adapter_to_index: Dict[str, int], @@ -1282,6 +1312,7 @@ def get_model_with_lora_adapters( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, max_input_tokens, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 30656038..d0425fec 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -28,7 +28,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.utils.import_utils import SYSTEM @@ -291,15 +290,15 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store(key=key, value=value, slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, + kv_cache.key if PREFILL_IN_KV_CACHE else key, + kv_cache.value if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, @@ -308,8 +307,8 @@ class FlashCohereAttention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 1137a453..b2b0cecb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -28,7 +28,6 @@ if SYSTEM != "ipex": from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, PREFILL_IN_KV_CACHE, ) @@ -330,15 +329,15 @@ class DbrxAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -347,8 +346,8 @@ class DbrxAttention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 88c2cf80..af77af8e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -33,7 +33,6 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, - reshape_and_cache, ) from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import FastRMSNorm @@ -321,15 +320,15 @@ class DeepseekV2Attention(torch.nn.Module): value, (0, self.head_pad_size - self.value_head_size), value=0 ) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store(key=key, value=value, slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, + kv_cache.key if PREFILL_IN_KV_CACHE else key, + kv_cache.value if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, @@ -338,8 +337,8 @@ class DeepseekV2Attention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 7a3d60c9..03b9b2a0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -28,7 +28,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -253,15 +252,15 @@ class FlashGemma2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -273,8 +272,8 @@ class FlashGemma2Attention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 4c1be6f6..f3c46901 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -28,7 +28,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, PREFILL_IN_KV_CACHE, ) @@ -224,15 +223,15 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -242,8 +241,8 @@ class FlashGemmaAttention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 44c015cf..94a8898d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -28,7 +28,6 @@ from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -224,15 +223,15 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store(key=key, value=value, slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, + kv_cache.key if PREFILL_IN_KV_CACHE else key, + kv_cache.value if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, @@ -241,8 +240,8 @@ class FlashGPT2Attention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index aca97004..f0a1270e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -28,7 +28,6 @@ from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -186,15 +185,15 @@ class FlashGPTJAttention(torch.nn.Module): else: self.rotary_emb(query, key, cos, sin) - reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + kv_cache.store(key=key, value=value, slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key, - kv_cache[1] if PREFILL_IN_KV_CACHE else value, + kv_cache.key if PREFILL_IN_KV_CACHE else key, + kv_cache.value if PREFILL_IN_KV_CACHE else value, seqlen, block_tables, self.softmax_scale, @@ -203,8 +202,8 @@ class FlashGPTJAttention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index c9ec70cc..fbe45d79 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -27,13 +27,12 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE +from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, KVCache from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -202,7 +201,7 @@ class FlashLlamaAttention(torch.nn.Module): cos, sin, cu_seqlen_prefill, - kv_cache, + kv_cache: KVCache, block_tables, slots, seqlen, @@ -222,15 +221,15 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -239,8 +238,8 @@ class FlashLlamaAttention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 341a2352..8974035e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -30,7 +30,6 @@ from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -210,17 +209,15 @@ class MistralAttention(torch.nn.Module): else: kv_to_cache = kv - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots - ) + kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], + kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -230,8 +227,8 @@ class MistralAttention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 5836d30a..e7bc8320 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,7 +37,6 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, - reshape_and_cache, ) from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import FastRMSNorm @@ -258,17 +257,15 @@ class MixtralAttention(torch.nn.Module): else: kv_to_cache = kv - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots - ) + kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], + kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -278,8 +275,8 @@ class MixtralAttention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index ad4e382f..bcbea442 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -29,7 +29,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -165,15 +164,15 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) - reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) + kv_cache.store(key=qkv[:, 1], value=qkv[:, 2], slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( qkv[:, 0], - kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1], - kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2], + kv_cache.key if PREFILL_IN_KV_CACHE else qkv[:, 1], + kv_cache.value if PREFILL_IN_KV_CACHE else qkv[:, 2], seqlen, block_tables, self.softmax_scale, @@ -182,8 +181,8 @@ class FlashNeoxAttention(torch.nn.Module): else: attn_output = paged_attention( qkv[:, 0], - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 2a0dc606..cb7b6ee2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -9,7 +9,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -188,14 +187,14 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) # Prefill if cu_seqlen_prefill is not None: attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -204,8 +203,8 @@ class FlashPhiAttention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 02c788d3..8185885f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -8,7 +8,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -128,17 +127,15 @@ class Qwen2Attention(torch.nn.Module): else: kv_to_cache = kv - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots - ) + kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], + kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -148,8 +145,8 @@ class Qwen2Attention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 6671d85e..dac8ecf9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -18,7 +18,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( attention, paged_attention, - reshape_and_cache, Seqlen, ) @@ -200,15 +199,15 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) + kv_cache.store(key=kv[:, 0], value=kv[:, 1], slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], + kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], + kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], seqlen, block_tables, self.softmax_scale, @@ -217,8 +216,8 @@ class FlashRWAttention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, @@ -312,12 +311,8 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) - reshape_and_cache( - kv[:, :, 0].contiguous(), - kv[:, :, 1].contiguous(), - kv_cache[0], - kv_cache[1], - slots, + kv_cache.store( + key=kv[:, :, 0].contiguous(), value=kv[:, :, 1].contiguous(), slots=slots ) # Prefill @@ -325,8 +320,8 @@ class FlashRWLargeAttention(torch.nn.Module): # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), - kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), + kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), + kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), seqlen, block_tables, self.softmax_scale, @@ -335,8 +330,8 @@ class FlashRWLargeAttention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 43eb9687..5972d436 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -8,7 +8,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -284,17 +283,15 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - reshape_and_cache( - key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots - ) + kv_cache.store(key=key_value[:, 0], value=key_value[:, 1], slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1], + kv_cache.key if PREFILL_IN_KV_CACHE else key_value[:, 0], + kv_cache.value if PREFILL_IN_KV_CACHE else key_value[:, 1], seqlen, block_tables, self.softmax_scale, @@ -303,8 +300,8 @@ class FlashMQAttention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 4975cf22..037238b8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -29,7 +29,6 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, - reshape_and_cache, Seqlen, ) from text_generation_server.layers import ( @@ -233,17 +232,15 @@ class Starcoder2Attention(torch.nn.Module): else: kv_to_cache = kv - reshape_and_cache( - kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots - ) + kv_cache.store(key=kv_to_cache[:, 0], value=kv_to_cache[:, 1], slots=slots) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query, - kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], + kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], + kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], seqlen, block_tables, self.softmax_scale, @@ -253,8 +250,8 @@ class Starcoder2Attention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache[0], - kv_cache[1], + kv_cache.key, + kv_cache.value, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index bc9d44a0..33fe30a8 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -46,7 +46,7 @@ from text_generation_server.models.globals import ( TGI_WIGGLE_ROOM, get_adapter_to_index, ) -from text_generation_server.layers.attention import Seqlen +from text_generation_server.layers.attention import KVCache, Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.quantization import get_loader @@ -937,6 +937,7 @@ class FlashCausalLM(Model): # Deepseek V2 uses different QK and V dims. head_size: Optional[int] = None, skip_special_tokens: bool = True, + kv_cache_dtype: Optional[torch.dtype] = None, ): self.quantize = quantize self.process_group, rank, world_size = initialize_torch_distributed() @@ -1034,6 +1035,7 @@ class FlashCausalLM(Model): self.cuda_graphs = {} self.kv_cache = [] + self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flashinfer import ( @@ -1083,61 +1085,16 @@ class FlashCausalLM(Model): ): self.kv_cache = [] empty_cache() - - element_size = torch.tensor([], dtype=dtype).element_size() - if SYSTEM == "ipex" and device.type == "xpu": - x = 1 - else: - x = BLOCK_SIZE // element_size - - if ATTENTION in {"flashdecoding", "flashinfer"}: - self.kv_cache = [ - ( - torch.empty( - (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, BLOCK_SIZE, num_heads, head_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - elif SYSTEM == "ipex" and device == torch.device("cpu"): - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, BLOCK_SIZE, head_size), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, BLOCK_SIZE, head_size), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] - else: - self.kv_cache = [ - ( - torch.zeros( - (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), - dtype=dtype, - device=device, - ), - torch.zeros( - (num_blocks, num_heads, head_size, BLOCK_SIZE), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] + self.kv_cache = [ + KVCache( + num_blocks=num_blocks, + num_heads=num_heads, + head_size=head_size, + dtype=dtype, + device=device, + ) + for _ in range(num_layers) + ] def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) @@ -1258,7 +1215,7 @@ class FlashCausalLM(Model): self.num_layers, self.num_kv_heads, self.head_size, - self.dtype, + self.kv_cache_dtype, self.device, ) max_bt = batch.max_blocks @@ -1277,7 +1234,7 @@ class FlashCausalLM(Model): # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory - dtype_size = torch.tensor([], dtype=self.dtype).element_size() + dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size @@ -1291,6 +1248,8 @@ class FlashCausalLM(Model): + batch_num_blocks ) + log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") + del batch self.init_kv_cache( @@ -1298,7 +1257,7 @@ class FlashCausalLM(Model): self.num_layers, self.num_kv_heads, self.head_size, - self.dtype, + self.kv_cache_dtype, self.device, ) diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index e7dfd8e4..46e342a4 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -205,6 +205,7 @@ def serve( quantize: Optional[str], speculate: Optional[int], dtype: Optional[str], + kv_cache_dtype: Optional[str], trust_remote_code: bool, uds_path: Path, max_input_tokens: int, @@ -217,6 +218,7 @@ def serve( quantize: Optional[str] = None, speculate: Optional[int] = None, dtype: Optional[str] = None, + kv_cache_dtype: Optional[str] = None, trust_remote_code: bool = False, ): unix_socket_template = "unix://{}-{}" @@ -240,6 +242,7 @@ def serve( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, max_input_tokens, adapter_to_index, @@ -286,6 +289,7 @@ def serve( quantize, speculate, dtype, + kv_cache_dtype, trust_remote_code, ) )