updated doc

commit 15b351b4a9 (parent 1e6e7db02e)
```diff
@@ -21,8 +21,8 @@ E4M3 offers higher precision for representing floating point numbers. However, d
 
 ## Current Hardware Support
 
-* Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2.
-* AMD GPUs: Supports FP8E4M3.
+* Nvidia GPUs: Supports both FP8E4M3 (fp8) and FP8E5M2 (fp8_e5m2).
+* AMD GPUs: Supports FP8E4M3FNUZ (fp8).
 
 ## FP8 E5M2 KV Cache
 Example usage:
```
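For context on the formats named in this hunk, recent PyTorch builds expose the FP8 dtypes directly, so their trade-off (E4M3: more precision, less range; E5M2: more range, less precision; FNUZ: the AMD/ROCm variant) can be inspected with `torch.finfo`. The snippet below is illustrative only and is not part of this change; the FNUZ dtype is only present in PyTorch builds that ship it.

```python
import torch

# Illustrative only: compare the dynamic range of the FP8 variants mentioned above.
# Requires a PyTorch version that exposes these dtypes (2.1+ for e4m3fn / e5m2).
for name in ("float8_e4m3fn", "float8_e5m2", "float8_e4m3fnuz"):
    dtype = getattr(torch, name, None)
    if dtype is None:
        print(f"{name}: not available in this PyTorch build")
        continue
    info = torch.finfo(dtype)
    # E4M3 trades range for precision; E5M2 trades precision for range.
    print(f"{name}: max={info.max}, smallest normal={info.tiny}")
```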
```diff
@@ -91,6 +91,8 @@ Options:
 ## KV_CACHE_DTYPE
 ```shell
       --kv-cache-dtype <KV_CACHE_DTYPE>
+          Specify the data type for KV cache. By default, it uses the model's data type. CUDA 11.8+ supports 'fp8' (fp8_e4m3) and 'fp8_e5m2', while ROCm (AMD GPU) supports 'fp8' (fp8_e4m3fnuz). If 'fp8' is chosen, a model checkpoint with scales for the KV cache should be provided. If not provided, the KV cache scaling factors default to 1.0, which may impact accuracy.
+
           [env: KV_CACHE_DTYPE=]
           [possible values: fp8, fp8_e5m2]
 
```
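The option's constraints can be restated as a small sketch for deployment scripts; the helper below is hypothetical, not part of the launcher, and only mirrors the help text above.

```python
from typing import Optional

# Hypothetical helper, not part of the launcher: restates the documented
# constraints for --kv-cache-dtype / KV_CACHE_DTYPE.
SUPPORTED_KV_CACHE_DTYPES = {
    "cuda": {"fp8", "fp8_e5m2"},  # CUDA 11.8+: fp8 (e4m3) and fp8_e5m2
    "rocm": {"fp8"},              # ROCm: fp8 maps to e4m3fnuz
}


def check_kv_cache_dtype(value: Optional[str], platform: str) -> None:
    if value is None:
        return  # default: the KV cache reuses the model's data type
    allowed = SUPPORTED_KV_CACHE_DTYPES.get(platform, set())
    if value not in allowed:
        raise ValueError(
            f"kv-cache-dtype {value!r} is not supported on {platform}; allowed: {sorted(allowed)}"
        )
    if value == "fp8":
        # Per the help text: without checkpoint-provided scales, the scaling
        # factors default to 1.0, which may impact accuracy.
        print("note: provide a checkpoint with KV-cache scales for best accuracy")


check_kv_cache_dtype("fp8_e5m2", "cuda")  # accepted on Nvidia, rejected on ROCm
```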
```diff
@@ -729,6 +729,7 @@ class FlashCausalLM(Model):
         rank: int = 0,
         world_size: int = 1,
         sliding_window: Optional[int] = None,
+        kv_cache_dtype: Optional[torch.dtype] = None,
     ):
         self.num_layers = num_layers
         self.num_kv_heads = num_kv_heads
```
```diff
@@ -737,6 +738,8 @@ class FlashCausalLM(Model):
         self.cuda_graphs = {}
         self.kv_cache = []
 
+        self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype else dtype
+
         super(FlashCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
```
```diff
@@ -854,7 +857,7 @@ class FlashCausalLM(Model):
                 self.num_layers,
                 self.num_kv_heads,
                 self.head_size,
-                self.dtype,
+                self.kv_cache_dtype,
                 self.device,
             )
             max_bt = batch.max_blocks
```
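To see where the new argument ends up, here is a deliberately simplified sketch of a paged KV-cache allocation. The layout and `BLOCK_SIZE` below are assumptions for illustration and differ from the repository's real cache; the point is only that the buffers are created with `kv_cache_dtype` rather than the model's `dtype`.

```python
from typing import List, Tuple

import torch

BLOCK_SIZE = 16  # assumed block size; illustrative only


def allocate_kv_cache(
    num_blocks: int,
    num_layers: int,
    num_kv_heads: int,
    head_size: int,
    kv_cache_dtype: torch.dtype,
    device: torch.device,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    # Simplified layout: one (key, value) pair of paged buffers per layer,
    # allocated with the KV-cache dtype (e.g. torch.uint8 for fp8 storage).
    shape = (num_blocks, BLOCK_SIZE, num_kv_heads, head_size)
    return [
        (
            torch.empty(shape, dtype=kv_cache_dtype, device=device),
            torch.empty(shape, dtype=kv_cache_dtype, device=device),
        )
        for _ in range(num_layers)
    ]
```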
```diff
@@ -873,7 +876,7 @@ class FlashCausalLM(Model):
 
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
-        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
```
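A worked example of the sizing formula in this hunk shows why switching `dtype_size` to the KV-cache dtype matters: with fp8 stored as `torch.uint8`, each cache block is half the size of an fp16 block, so twice as many blocks fit in the same free memory. The model shape, free-memory figure, and `BLOCK_SIZE` below are made-up illustrative values.

```python
import torch

BLOCK_SIZE = 16  # assumed; illustrative only
num_layers, num_kv_heads, head_size = 32, 8, 128  # made-up model shape


def cache_bytes_per_block(kv_cache_dtype: torch.dtype) -> int:
    dtype_size = torch.tensor([], dtype=kv_cache_dtype).element_size()
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    # Keys and values are both cached, hence the factor of 2.
    return num_layers * cache_block_size * 2 * dtype_size


free_memory = 8 * 1024**3  # pretend 8 GiB remain after loading the weights
for dt in (torch.float16, torch.uint8):  # fp16 cache vs. fp8 stored as uint8
    print(dt, free_memory // cache_bytes_per_block(dt), "blocks")
# float16 -> 4096 blocks, uint8 -> 8192 blocks
```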
```diff
@@ -893,7 +896,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )
 
```
```diff
@@ -84,8 +84,9 @@ class FlashLlama(FlashCausalLM):
             num_layers=len(model.model.layers),
             num_kv_heads=model.model.num_key_value_heads,
             head_size=model.model.head_size,
-            dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
+            dtype=dtype,
             device=device,
             rank=rank,
             world_size=world_size,
+            kv_cache_dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
         )
```
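This hunk keeps the model weights in `dtype` and routes the uint8 storage type to the new `kv_cache_dtype` argument instead of overriding the model dtype. The expression can be read as a tiny helper; the function below is a restatement for clarity, not code from the repository.

```python
from typing import Optional

import torch


def resolve_kv_cache_dtype(kv_cache_dtype: Optional[str], model_dtype: torch.dtype) -> torch.dtype:
    # Mirrors the expression above: fp8 variants ("fp8", "fp8_e5m2") are stored
    # as raw bytes (torch.uint8); anything else falls back to the model's dtype.
    if kv_cache_dtype and "fp8" in kv_cache_dtype:
        return torch.uint8
    return model_dtype


print(resolve_kv_cache_dtype("fp8_e5m2", torch.float16))  # torch.uint8
print(resolve_kv_cache_dtype(None, torch.float16))        # torch.float16
```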