Support `e4m3fn` KV cache (#2655)
* Support `e4m3fn` KV cache * Make check more obvious
This commit is contained in:
parent
a6a0c97ed9
commit
5bbe1ce028
|
@ -93,10 +93,10 @@ Options:
|
||||||
## KV_CACHE_DTYPE
|
## KV_CACHE_DTYPE
|
||||||
```shell
|
```shell
|
||||||
--kv-cache-dtype <KV_CACHE_DTYPE>
|
--kv-cache-dtype <KV_CACHE_DTYPE>
|
||||||
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
|
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA
|
||||||
|
|
||||||
[env: KV_CACHE_DTYPE=]
|
[env: KV_CACHE_DTYPE=]
|
||||||
[possible values: fp8_e5m2]
|
[possible values: fp8_e4m3fn, fp8_e5m2]
|
||||||
|
|
||||||
```
|
```
|
||||||
## TRUST_REMOTE_CODE
|
## TRUST_REMOTE_CODE
|
||||||
|
|
|
@ -307,6 +307,9 @@ impl std::fmt::Display for Dtype {
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
enum KVCacheDtype {
|
enum KVCacheDtype {
|
||||||
|
#[clap(name = "fp8_e4m3fn")]
|
||||||
|
Fp8e4m3fn,
|
||||||
|
|
||||||
#[clap(name = "fp8_e5m2")]
|
#[clap(name = "fp8_e5m2")]
|
||||||
Fp8e5m2,
|
Fp8e5m2,
|
||||||
}
|
}
|
||||||
|
@ -314,6 +317,9 @@ enum KVCacheDtype {
|
||||||
impl std::fmt::Display for KVCacheDtype {
|
impl std::fmt::Display for KVCacheDtype {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
|
KVCacheDtype::Fp8e4m3fn => {
|
||||||
|
write!(f, "fp8_e4m3fn")
|
||||||
|
}
|
||||||
KVCacheDtype::Fp8e5m2 => {
|
KVCacheDtype::Fp8e5m2 => {
|
||||||
write!(f, "fp8_e5m2")
|
write!(f, "fp8_e5m2")
|
||||||
}
|
}
|
||||||
|
@ -424,7 +430,7 @@ struct Args {
|
||||||
|
|
||||||
/// Specify the dtype for the key-value cache. When this option is not provided,
|
/// Specify the dtype for the key-value cache. When this option is not provided,
|
||||||
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
|
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
|
||||||
/// the only supported value is `fp8_e5m2` on CUDA.
|
/// the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA.
|
||||||
#[clap(long, env, value_enum)]
|
#[clap(long, env, value_enum)]
|
||||||
kv_cache_dtype: Option<KVCacheDtype>,
|
kv_cache_dtype: Option<KVCacheDtype>,
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ class Dtype(str, Enum):
|
||||||
|
|
||||||
|
|
||||||
class KVCacheDtype(str, Enum):
|
class KVCacheDtype(str, Enum):
|
||||||
|
fp8_e4m3fn = "fp8_e4m3fn"
|
||||||
fp8_e5m2 = "fp8_e5m2"
|
fp8_e5m2 = "fp8_e5m2"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -24,11 +24,11 @@ class KVCache:
|
||||||
):
|
):
|
||||||
"""Construct the key-value cache for a layer."""
|
"""Construct the key-value cache for a layer."""
|
||||||
|
|
||||||
if dtype == torch.float8_e5m2 and (
|
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
|
||||||
ATTENTION != "flashinfer" or SYSTEM != "cuda"
|
ATTENTION != "flashinfer" or SYSTEM != "cuda"
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
|
"FP8 KV cache is currently only supported for flashinfer on CUDA"
|
||||||
)
|
)
|
||||||
|
|
||||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
@ -105,8 +105,8 @@ class KVCache:
|
||||||
# TODO: add scale
|
# TODO: add scale
|
||||||
key = key.to(key_cache.dtype)
|
key = key.to(key_cache.dtype)
|
||||||
value = value.to(value_cache.dtype)
|
value = value.to(value_cache.dtype)
|
||||||
if key_cache.dtype == torch.float8_e5m2:
|
if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
|
||||||
# Torch index_put does not support float8_e5m2 yet, so
|
# Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
|
||||||
# put as raw data instead.
|
# put as raw data instead.
|
||||||
key_cache = key_cache.view(torch.uint8)
|
key_cache = key_cache.view(torch.uint8)
|
||||||
value_cache = value_cache.view(torch.uint8)
|
value_cache = value_cache.view(torch.uint8)
|
||||||
|
|
|
@ -421,6 +421,8 @@ def get_model(
|
||||||
|
|
||||||
if kv_cache_dtype is None:
|
if kv_cache_dtype is None:
|
||||||
kv_cache_dtype = dtype
|
kv_cache_dtype = dtype
|
||||||
|
elif kv_cache_dtype == "fp8_e4m3fn":
|
||||||
|
kv_cache_dtype = torch.float8_e4m3fn
|
||||||
elif kv_cache_dtype == "fp8_e5m2":
|
elif kv_cache_dtype == "fp8_e5m2":
|
||||||
kv_cache_dtype = torch.float8_e5m2
|
kv_cache_dtype = torch.float8_e5m2
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue