Support `e4m3fn` KV cache (#2655)

* Support `e4m3fn` KV cache

* Make check more obvious
This commit is contained in:
Daniël de Kok 2024-10-17 10:42:16 +02:00 committed by GitHub
parent a6a0c97ed9
commit 5bbe1ce028
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 16 additions and 7 deletions

View File

@ -93,10 +93,10 @@ Options:
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value are `fp8_e4m3fn` and `fp8_e5m2` on CUDA
[env: KV_CACHE_DTYPE=]
[possible values: fp8_e5m2]
[possible values: fp8_e4m3fn, fp8_e5m2]
```
## TRUST_REMOTE_CODE

View File

@ -307,6 +307,9 @@ impl std::fmt::Display for Dtype {
#[derive(Clone, Copy, Debug, ValueEnum)]
enum KVCacheDtype {
#[clap(name = "fp8_e4m3fn")]
Fp8e4m3fn,
#[clap(name = "fp8_e5m2")]
Fp8e5m2,
}
@ -314,6 +317,9 @@ enum KVCacheDtype {
impl std::fmt::Display for KVCacheDtype {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
KVCacheDtype::Fp8e4m3fn => {
write!(f, "fp8_e4m3fn")
}
KVCacheDtype::Fp8e5m2 => {
write!(f, "fp8_e5m2")
}
@ -424,7 +430,7 @@ struct Args {
/// Specify the dtype for the key-value cache. When this option is not provided,
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
/// the only supported value is `fp8_e5m2` on CUDA.
/// the only supported value are `fp8_e4m3fn` and `fp8_e5m2` on CUDA.
#[clap(long, env, value_enum)]
kv_cache_dtype: Option<KVCacheDtype>,

View File

@ -31,6 +31,7 @@ class Dtype(str, Enum):
class KVCacheDtype(str, Enum):
fp8_e4m3fn = "fp8_e4m3fn"
fp8_e5m2 = "fp8_e5m2"

View File

@ -24,11 +24,11 @@ class KVCache:
):
"""Construct the key-value cache for a layer."""
if dtype == torch.float8_e5m2 and (
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
ATTENTION != "flashinfer" or SYSTEM != "cuda"
):
raise ValueError(
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
"FP8 KV cache is currently only supported for flashinfer on CUDA"
)
element_size = torch.tensor([], dtype=dtype).element_size()
@ -105,8 +105,8 @@ class KVCache:
# TODO: add scale
key = key.to(key_cache.dtype)
value = value.to(value_cache.dtype)
if key_cache.dtype == torch.float8_e5m2:
# Torch index_put does not support float8_e5m2 yet, so
if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
# Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
# put as raw data instead.
key_cache = key_cache.view(torch.uint8)
value_cache = value_cache.view(torch.uint8)

View File

@ -421,6 +421,8 @@ def get_model(
if kv_cache_dtype is None:
kv_cache_dtype = dtype
elif kv_cache_dtype == "fp8_e4m3fn":
kv_cache_dtype = torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
kv_cache_dtype = torch.float8_e5m2
else: