updated doc
commit 15b351b4a9 (parent 1e6e7db02e)
@@ -21,8 +21,8 @@ E4M3 offers higher precision for representing floating point numbers. However, d

 ## Current Hardware Support

-* Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2.
-* AMD GPUs: Supports FP8E4M3.
+* Nvidia GPUs: Supports both FP8E4M3 (fp8) and FP8E5M2 (fp8_e5m2).
+* AMD GPUs: Supports FP8E4M3FNUZ (fp8).

 ## FP8 E5M2 KV Cache

 Example usage:
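As a quick way to see the precision/range trade-off between the two formats, their numeric limits can be inspected directly. This is a minimal sketch assuming a PyTorch build that exposes the float8 dtypes (2.1+); it is not part of the change itself:

```python
import torch

# Minimal sketch: compare the numeric limits of the two FP8 formats.
# Assumes a PyTorch build that exposes the float8 dtypes (2.1+).
for name, dtype in [("FP8 E4M3", torch.float8_e4m3fn), ("FP8 E5M2", torch.float8_e5m2)]:
    info = torch.finfo(dtype)
    # E4M3 has a smaller range but finer resolution (smaller eps) than E5M2.
    print(f"{name}: max={info.max}, smallest normal={info.tiny}, eps={info.eps}")
```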
@@ -91,6 +91,8 @@ Options:

 ## KV_CACHE_DTYPE

 ```shell
   --kv-cache-dtype <KV_CACHE_DTYPE>
+          Specify the data type for KV cache. By default, it uses the model's data type. CUDA 11.8+ supports `fp8` (`fp8_e4m3`) and `fp8_e5m2`, while ROCm (AMD GPU) supports `fp8` (`fp8_e4m3fnuz`). If `fp8` is chosen, a model checkpoint with scales for the KV cache should be provided. If not provided, the KV cache scaling factors default to 1.0, which may impact accuracy.
+
           [env: KV_CACHE_DTYPE=]
           [possible values: fp8, fp8_e5m2]
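To illustrate why the scaling factors matter for `fp8` (E4M3), here is a small standalone sketch; the helper names and the `E4M3_MAX` constant are illustrative and not taken from the server code:

```python
import torch

# Illustrative sketch of per-tensor fp8 (E4M3) KV quantization with a scale;
# not the server's actual code path.
E4M3_MAX = 448.0  # largest finite value representable in E4M3

def quantize_kv(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Divide by the scale so values fit E4M3's limited range, then cast down.
    return (x / scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)

def dequantize_kv(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Cast back up and multiply the scale back in.
    return q.to(torch.float32) * scale

key = torch.randn(16, 64) * 50.0      # hypothetical KV values
scale = key.abs().max() / E4M3_MAX    # normally provided by the checkpoint scales
q = quantize_kv(key, scale)
print("max abs error:", (key - dequantize_kv(q, scale)).abs().max().item())
```

With a scale of 1.0, any value outside E4M3's roughly ±448 range cannot be represented faithfully, which is the accuracy impact the help text warns about.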
@@ -729,6 +729,7 @@ class FlashCausalLM(Model):
         rank: int = 0,
         world_size: int = 1,
         sliding_window: Optional[int] = None,
+        kv_cache_dtype: Optional[torch.dtype] = None,
     ):
         self.num_layers = num_layers
         self.num_kv_heads = num_kv_heads
@@ -737,6 +738,8 @@ class FlashCausalLM(Model):
         self.cuda_graphs = {}
         self.kv_cache = []

+        self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype else dtype
+
         super(FlashCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
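The new argument is optional, so existing callers keep their behaviour: when no `kv_cache_dtype` is given, the KV cache is allocated in the model dtype. A tiny sketch of that fallback (the helper name is illustrative):

```python
import torch
from typing import Optional

# Sketch of the fallback above: the KV cache follows the model dtype unless an
# explicit kv_cache_dtype (e.g. torch.uint8 for an fp8 cache) is passed.
def resolve_kv_cache_dtype(dtype: torch.dtype, kv_cache_dtype: Optional[torch.dtype]) -> torch.dtype:
    return kv_cache_dtype if kv_cache_dtype else dtype

assert resolve_kv_cache_dtype(torch.float16, None) is torch.float16
assert resolve_kv_cache_dtype(torch.float16, torch.uint8) is torch.uint8
```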
@@ -854,7 +857,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )
         max_bt = batch.max_blocks
@@ -873,7 +876,7 @@ class FlashCausalLM(Model):

         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
-        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
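Since `element_size()` is now taken from the KV cache dtype, an fp8 cache stored as `torch.uint8` (1 byte per element) needs half the memory per block of an fp16 cache (2 bytes). A rough sketch of the arithmetic, with made-up model dimensions:

```python
import torch

# Rough sketch of the cache-size arithmetic above; the dimensions are made up.
BLOCK_SIZE = 16      # tokens per cache block (illustrative value)
num_layers = 32
num_kv_heads = 8
head_size = 128

for kv_cache_dtype in (torch.float16, torch.uint8):  # fp16 cache vs fp8-as-uint8 cache
    dtype_size = torch.tensor([], dtype=kv_cache_dtype).element_size()
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # *2: keys and values
    print(f"{kv_cache_dtype}: {total_cache_size} bytes per block")
```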
@@ -893,7 +896,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )
@@ -84,8 +84,9 @@ class FlashLlama(FlashCausalLM):
             num_layers=len(model.model.layers),
             num_kv_heads=model.model.num_key_value_heads,
             head_size=model.model.head_size,
-            dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
+            dtype=dtype,
             device=device,
             rank=rank,
             world_size=world_size,
+            kv_cache_dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
         )
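Here `kv_cache_dtype` arrives as a string value (e.g. `fp8` or `fp8_e5m2`, per the launcher docs above) and is mapped to a torch dtype before being forwarded to `FlashCausalLM`, while the model weights keep their own `dtype`. A standalone sketch of that mapping (the function name and the non-fp8 example value are illustrative):

```python
import torch

# Sketch of the string-to-dtype mapping used above: any "fp8" variant is stored
# in the KV cache as raw uint8 bytes; otherwise the cache follows the model dtype.
def map_kv_cache_dtype(kv_cache_dtype: str, dtype: torch.dtype) -> torch.dtype:
    return torch.uint8 if "fp8" in kv_cache_dtype else dtype

assert map_kv_cache_dtype("fp8", torch.float16) is torch.uint8
assert map_kv_cache_dtype("fp8_e5m2", torch.float16) is torch.uint8
assert map_kv_cache_dtype("auto", torch.float16) is torch.float16  # illustrative non-fp8 value
```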