From 5bbe1ce028c9ffc116e3c5b19d20d3b279109b95 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Thu, 17 Oct 2024 10:42:16 +0200
Subject: [PATCH] Support `e4m3fn` KV cache (#2655)

* Support `e4m3fn` KV cache

* Make check more obvious
---
 docs/source/reference/launcher.md                        | 4 ++--
 launcher/src/main.rs                                     | 8 +++++++-
 server/text_generation_server/cli.py                     | 1 +
 .../text_generation_server/layers/attention/kv_cache.py | 8 ++++----
 server/text_generation_server/models/__init__.py         | 2 ++
 5 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/docs/source/reference/launcher.md b/docs/source/reference/launcher.md
index b1abd1ee..68e487d0 100644
--- a/docs/source/reference/launcher.md
+++ b/docs/source/reference/launcher.md
@@ -93,10 +93,10 @@ Options:
 ## KV_CACHE_DTYPE
 ```shell
       --kv-cache-dtype
-          Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
+          Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA
 
           [env: KV_CACHE_DTYPE=]
 
-          [possible values: fp8_e5m2]
+          [possible values: fp8_e4m3fn, fp8_e5m2]
 ```
 
 ## TRUST_REMOTE_CODE
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 0d7af66d..d9f569fd 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -307,6 +307,9 @@ impl std::fmt::Display for Dtype {
 
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum KVCacheDtype {
+    #[clap(name = "fp8_e4m3fn")]
+    Fp8e4m3fn,
+
     #[clap(name = "fp8_e5m2")]
     Fp8e5m2,
 }
@@ -314,6 +317,9 @@ enum KVCacheDtype {
 impl std::fmt::Display for KVCacheDtype {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
+            KVCacheDtype::Fp8e4m3fn => {
+                write!(f, "fp8_e4m3fn")
+            }
             KVCacheDtype::Fp8e5m2 => {
                 write!(f, "fp8_e5m2")
             }
@@ -424,7 +430,7 @@ struct Args {
     /// Specify the dtype for the key-value cache. When this option is not provided,
     /// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
-    /// the only supported value is `fp8_e5m2` on CUDA.
+    /// the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA.
     #[clap(long, env, value_enum)]
     kv_cache_dtype: Option<KVCacheDtype>,
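Not part of the patch, but useful context for the new option: `fp8_e4m3fn` and `fp8_e5m2` split the same 8 bits differently. e4m3fn spends 4 bits on the exponent and 3 on the mantissa (finite-only, max 448), while e5m2 spends 5 on the exponent and 2 on the mantissa (max 57344). A minimal sketch, assuming a PyTorch build with FP8 dtypes (2.1 or later):

```python
import torch

# Compare the numeric envelopes of the two FP8 formats the launcher accepts.
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max}, smallest normal={info.tiny}, eps={info.eps}")

# Round-tripping a value shows the precision difference: e4m3fn has one more
# mantissa bit, so it lands closer to the original than e5m2 does.
x = torch.tensor([3.14159])
print(x.to(torch.float8_e4m3fn).to(torch.float32))
print(x.to(torch.float8_e5m2).to(torch.float32))
```

Attention keys and values are usually small in magnitude, which is why the extra mantissa bit of e4m3fn is often the more useful trade for a KV cache than e5m2's extra range.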
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index db390234..a363b33a 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -31,6 +31,7 @@ class Dtype(str, Enum):
 
 
 class KVCacheDtype(str, Enum):
+    fp8_e4m3fn = "fp8_e4m3fn"
     fp8_e5m2 = "fp8_e5m2"
 
 
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 3960c954..7f1dd370 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -24,11 +24,11 @@ class KVCache:
     ):
         """Construct the key-value cache for a layer."""
 
-        if dtype == torch.float8_e5m2 and (
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
             ATTENTION != "flashinfer" or SYSTEM != "cuda"
         ):
             raise ValueError(
-                "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
+                "FP8 KV cache is currently only supported for flashinfer on CUDA"
             )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
@@ -105,8 +105,8 @@ class KVCache:
             # TODO: add scale
             key = key.to(key_cache.dtype)
             value = value.to(value_cache.dtype)
-            if key_cache.dtype == torch.float8_e5m2:
-                # Torch index_put does not support float8_e5m2 yet, so
+            if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+                # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
                 # put as raw data instead.
                 key_cache = key_cache.view(torch.uint8)
                 value_cache = value_cache.view(torch.uint8)
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 019617d2..de0c66e7 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -421,6 +421,8 @@ def get_model(
 
     if kv_cache_dtype is None:
         kv_cache_dtype = dtype
+    elif kv_cache_dtype == "fp8_e4m3fn":
+        kv_cache_dtype = torch.float8_e4m3fn
     elif kv_cache_dtype == "fp8_e5m2":
         kv_cache_dtype = torch.float8_e5m2
     else:
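As an aside on the `index_put` workaround in `kv_cache.py`: since PyTorch has no FP8 kernel for indexed writes, the cache tensors are reinterpreted as raw bytes for the scatter, and the stored bits remain valid FP8 data. A standalone sketch of the same trick (tensor names and shapes are illustrative, not TGI's):

```python
import torch

# A tiny stand-in for one layer's key cache: [num_slots, head_dim] in FP8.
key_cache = torch.zeros(8, 4, dtype=torch.float8_e4m3fn)
key = torch.randn(2, 4).to(torch.float8_e4m3fn)  # new keys to store
slots = torch.tensor([3, 5])                     # destination cache slots

# A direct `key_cache[slots] = key` would hit the missing FP8 index_put
# kernel, so both sides are viewed as uint8 (same 1-byte element size) and
# the write happens on raw bytes; the view shares the cache's storage.
key_cache.view(torch.uint8)[slots] = key.view(torch.uint8)

# Reading back through the FP8 tensor confirms the bits landed intact.
print(key_cache[3].to(torch.float32))
print(key[0].to(torch.float32))
```

Viewing as `uint8` rather than copying keeps the operation in place, which matters here because the cache is a long-lived, preallocated tensor.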