Support `e4m3fn` KV cache (#2655)

* Support `e4m3fn` KV cache

* Make check more obvious
Daniël de Kok 2024-10-17 10:42:16 +02:00 committed by GitHub
parent a6a0c97ed9
commit 5bbe1ce028
5 changed files with 16 additions and 7 deletions


@@ -93,10 +93,10 @@ Options:
 ## KV_CACHE_DTYPE
 ```shell
       --kv-cache-dtype <KV_CACHE_DTYPE>
-          Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
+          Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA
           [env: KV_CACHE_DTYPE=]
-          [possible values: fp8_e5m2]
+          [possible values: fp8_e4m3fn, fp8_e5m2]
 ```
 ## TRUST_REMOTE_CODE


@@ -307,6 +307,9 @@ impl std::fmt::Display for Dtype {
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum KVCacheDtype {
+    #[clap(name = "fp8_e4m3fn")]
+    Fp8e4m3fn,
     #[clap(name = "fp8_e5m2")]
     Fp8e5m2,
 }
@@ -314,6 +317,9 @@ enum KVCacheDtype {
 impl std::fmt::Display for KVCacheDtype {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
+            KVCacheDtype::Fp8e4m3fn => {
+                write!(f, "fp8_e4m3fn")
+            }
             KVCacheDtype::Fp8e5m2 => {
                 write!(f, "fp8_e5m2")
             }
@@ -424,7 +430,7 @@ struct Args {
     /// Specify the dtype for the key-value cache. When this option is not provided,
     /// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
-    /// the only supported value is `fp8_e5m2` on CUDA.
+    /// the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA.
     #[clap(long, env, value_enum)]
     kv_cache_dtype: Option<KVCacheDtype>,


@@ -31,6 +31,7 @@ class Dtype(str, Enum):
 class KVCacheDtype(str, Enum):
+    fp8_e4m3fn = "fp8_e4m3fn"
     fp8_e5m2 = "fp8_e5m2"


@@ -24,11 +24,11 @@ class KVCache:
     ):
         """Construct the key-value cache for a layer."""
-        if dtype == torch.float8_e5m2 and (
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
             ATTENTION != "flashinfer" or SYSTEM != "cuda"
         ):
             raise ValueError(
-                "float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
+                "FP8 KV cache is currently only supported for flashinfer on CUDA"
             )
         element_size = torch.tensor([], dtype=dtype).element_size()
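
As an aside on why an FP8 cache is attractive: the `element_size()` computed above is 1 byte for both FP8 formats, half the per-element footprint of a `float16` or `bfloat16` cache. A minimal sketch, assuming a PyTorch build that ships the `float8` dtypes (roughly 2.1 and later):

```python
import torch

# Bytes per cache element: float16 takes 2, both FP8 formats take 1,
# so an FP8 KV cache roughly halves cache memory.
for dt in (torch.float16, torch.float8_e4m3fn, torch.float8_e5m2):
    print(dt, torch.tensor([], dtype=dt).element_size())
```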
@@ -105,8 +105,8 @@ class KVCache:
         # TODO: add scale
         key = key.to(key_cache.dtype)
         value = value.to(value_cache.dtype)
-        if key_cache.dtype == torch.float8_e5m2:
-            # Torch index_put does not support float8_e5m2 yet, so
+        if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
             # put as raw data instead.
             key_cache = key_cache.view(torch.uint8)
             value_cache = value_cache.view(torch.uint8)
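
The raw-data workaround above can be reproduced in isolation: `index_put` is not implemented for the `float8` dtypes, but `Tensor.view(dtype)` onto an equally sized dtype reinterprets the same storage, so the bytes can be written through a `uint8` view instead. A minimal sketch with made-up shapes and names, not the repository's code:

```python
import torch

num_slots, head_dim = 8, 4
cache = torch.zeros(num_slots, head_dim).to(torch.float8_e4m3fn)
key = torch.randn(2, head_dim).to(torch.float8_e4m3fn)
slots = torch.tensor([1, 5])

# `cache[slots] = key` would raise, since index_put does not support float8.
# Writing through a uint8 view stores the same raw bytes; the view shares
# storage with `cache`, so the cache sees the new values afterwards.
cache.view(torch.uint8)[slots] = key.view(torch.uint8)
print(cache.to(torch.float16)[slots])
```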


@@ -421,6 +421,8 @@ def get_model(
     if kv_cache_dtype is None:
         kv_cache_dtype = dtype
+    elif kv_cache_dtype == "fp8_e4m3fn":
+        kv_cache_dtype = torch.float8_e4m3fn
     elif kv_cache_dtype == "fp8_e5m2":
         kv_cache_dtype = torch.float8_e5m2
     else:
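
For background on why both formats are offered (not taken from the diff): `e4m3fn` spends 4 bits on the exponent and 3 on the mantissa and has no infinities (the `fn` suffix), giving finer precision but a maximum normal value of 448, while `e5m2` keeps 5 exponent bits and reaches 57344 at coarser precision. A quick way to inspect this, assuming PyTorch's `float8` dtypes are available:

```python
import torch

# e4m3fn: 3 mantissa bits, max normal value 448, no infinities.
# e5m2:   2 mantissa bits, max normal value 57344, wider range.
print(torch.finfo(torch.float8_e4m3fn))
print(torch.finfo(torch.float8_e5m2))
```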