updated doc

Mohit Sharma 2024-06-25 15:35:49 +00:00
parent 1e6e7db02e
commit 15b351b4a9
4 changed files with 12 additions and 6 deletions

View File

@@ -21,8 +21,8 @@ E4M3 offers higher precision for representing floating point numbers. However, d
## Current Hardware Support
* Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2.
* AMD GPUs: Supports FP8E4M3.
* Nvidia GPUs: Supports both FP8E4M3 (fp8) and FP8E5M2 (fp8_e5m2).
* AMD GPUs: Supports FP8E4M3FNUZ (fp8).
## FP8 E5M2 KV Cache
Example usage:

View File

@@ -91,6 +91,8 @@ Options:
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
Specify the data type for the KV cache. By default, it uses the model's data type. CUDA 11.8+ supports `fp8` (fp8_e4m3) and `fp8_e5m2`, while ROCm (AMD GPU) supports `fp8` (fp8_e4m3fnuz). If `fp8` is chosen, a model checkpoint with scales for the KV cache should be provided. If not provided, the KV cache scaling factors default to 1.0, which may impact accuracy.
[env: KV_CACHE_DTYPE=]
[possible values: fp8, fp8_e5m2]

View File

@@ -729,6 +729,7 @@ class FlashCausalLM(Model):
rank: int = 0,
world_size: int = 1,
sliding_window: Optional[int] = None,
kv_cache_dtype: Optional[torch.dtype] = None,
):
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
@@ -737,6 +738,8 @@ class FlashCausalLM(Model):
self.cuda_graphs = {}
self.kv_cache = []
self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype else dtype
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
@@ -854,7 +857,7 @@ class FlashCausalLM(Model):
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.device,
)
max_bt = batch.max_blocks
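The hunks in this file's warmup path now pass `self.kv_cache_dtype` (rather than `self.dtype`) into the KV-cache allocation call. The call site is truncated in the diff, so the following is only a hedged sketch of such an allocation; the function name, `block_size` parameter, and tensor layout are illustrative assumptions, not TGI's actual implementation:

```python
from typing import List, Tuple
import torch

def allocate_kv_cache(
    num_blocks: int,
    num_layers: int,
    num_kv_heads: int,
    head_size: int,
    block_size: int,
    dtype: torch.dtype,     # e.g. torch.uint8 when the cache is stored as fp8
    device: torch.device,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    # One (key, value) buffer pair per layer, allocated in the KV-cache dtype,
    # which may differ from the dtype used for the model weights.
    return [
        (
            torch.empty(num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device),
            torch.empty(num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device),
        )
        for _ in range(num_layers)
    ]
```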
@@ -873,7 +876,7 @@ class FlashCausalLM(Model):
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.dtype).element_size()
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
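This hunk sizes the cache from the element size of the KV-cache dtype instead of the model dtype. Below is a hedged sketch of the full block-count calculation; the free-memory input and the final division are not shown in the hunk and are assumptions modeled on the vLLM-style approach the comment references:

```python
import torch

BLOCK_SIZE = 16  # tokens per cache block; assumed value, the real constant is imported in the surrounding file

def num_cache_blocks(
    free_memory: int,
    num_layers: int,
    num_kv_heads: int,
    head_size: int,
    kv_cache_dtype: torch.dtype,
) -> int:
    # Bytes per element of the KV-cache storage dtype
    # (1 for torch.uint8/fp8, 2 for torch.float16).
    dtype_size = torch.tensor([], dtype=kv_cache_dtype).element_size()
    # Elements in one block of one layer, for one of K or V.
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
    # Bytes one block occupies across all layers, counting both K and V.
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size
    # Blocks that fit in the memory left over after weights and activations.
    return int(free_memory // total_cache_size)
```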
@@ -893,7 +896,7 @@ class FlashCausalLM(Model):
self.num_layers,
self.num_kv_heads,
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.device,
)
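
As context for why the storage dtype matters in the calculation above: an fp8 cache kept in a `torch.uint8` container (see the FlashLlama hunk below) uses one byte per element versus two for `torch.float16`, so the same free memory yields roughly twice as many cache blocks. A quick check of the element sizes the formula multiplies by:

```python
import torch

# element_size() is the `dtype_size` factor in the block-count math above.
print(torch.tensor([], dtype=torch.float16).element_size())  # 2 bytes per element
print(torch.tensor([], dtype=torch.uint8).element_size())    # 1 byte per element (fp8 container)
```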

View File

@@ -84,8 +84,9 @@ class FlashLlama(FlashCausalLM):
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
kv_cache_dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
)
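
The FlashLlama change keeps the model weights in their original `dtype` and moves the fp8 decision into the new `kv_cache_dtype` argument, so only the KV cache is stored in a `torch.uint8` container. A minimal sketch of that selection logic in isolation (the helper name and the non-fp8 example string are illustrative assumptions):

```python
import torch

def resolve_kv_cache_dtype(kv_cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    # Mirrors the expression in the diff: any "fp8" variant is stored as uint8,
    # everything else falls back to the model's own dtype.
    if "fp8" in kv_cache_dtype:
        return torch.uint8
    return model_dtype

# Weights stay fp16; the KV cache is allocated as a uint8 (fp8) container.
assert resolve_kv_cache_dtype("fp8_e5m2", torch.float16) == torch.uint8
assert resolve_kv_cache_dtype("float16", torch.float16) == torch.float16
```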