From 15b351b4a9b31a87c4bd19e8554ed5a45bd67830 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Tue, 25 Jun 2024 15:35:49 +0000
Subject: [PATCH] updated doc

---
 docs/source/basic_tutorials/fp8_kv_cache.md             | 4 ++--
 docs/source/basic_tutorials/launcher.md                 | 2 ++
 server/text_generation_server/models/flash_causal_lm.py | 9 ++++++---
 server/text_generation_server/models/flash_llama.py     | 3 ++-
 4 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/docs/source/basic_tutorials/fp8_kv_cache.md b/docs/source/basic_tutorials/fp8_kv_cache.md
index af9a072b..f4884d3f 100644
--- a/docs/source/basic_tutorials/fp8_kv_cache.md
+++ b/docs/source/basic_tutorials/fp8_kv_cache.md
@@ -21,8 +21,8 @@ E4M3 offers higher precision for representing floating point numbers. However, d
 
 ## Current Hardware Support
 
-* Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2.
-* AMD GPUs: Supports FP8E4M3.
+* Nvidia GPUs: Supports both FP8E4M3 (fp8) and FP8E5M2 (fp8_e5m2).
+* AMD GPUs: Supports FP8E4M3FNUZ (fp8).
 
 ## FP8 E5M2 KV Cache
 Example usage:
diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md
index faa97c5c..bf6a6934 100644
--- a/docs/source/basic_tutorials/launcher.md
+++ b/docs/source/basic_tutorials/launcher.md
@@ -91,6 +91,8 @@ Options:
 ## KV_CACHE_DTYPE
 ```shell
       --kv-cache-dtype
+          Specify the data type for the KV cache. By default, it uses the model's data type. CUDA 11.8+ supports `fp8` (fp8_e4m3) and `fp8_e5m2`, while ROCm (AMD GPU) supports `fp8` (fp8_e4m3fnuz). If `fp8` is chosen, a model checkpoint with scales for the KV cache should be provided; otherwise, the KV cache scaling factors default to 1.0, which may impact accuracy.
+
           [env: KV_CACHE_DTYPE=]
           [possible values: fp8, fp8_e5m2]
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d16d3710..470084c9 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -729,6 +729,7 @@ class FlashCausalLM(Model):
         rank: int = 0,
         world_size: int = 1,
         sliding_window: Optional[int] = None,
+        kv_cache_dtype: Optional[torch.dtype] = None,
     ):
         self.num_layers = num_layers
         self.num_kv_heads = num_kv_heads
@@ -737,6 +738,8 @@ class FlashCausalLM(Model):
         self.cuda_graphs = {}
         self.kv_cache = []
 
+        self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype else dtype
+
         super(FlashCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
@@ -854,7 +857,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )
         max_bt = batch.max_blocks
@@ -873,7 +876,7 @@ class FlashCausalLM(Model):
 
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
-        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
+        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
@@ -893,7 +896,7 @@ class FlashCausalLM(Model):
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.dtype,
+            self.kv_cache_dtype,
             self.device,
         )
 
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 85d8ad10..37628513 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -84,8 +84,9 @@ class FlashLlama(FlashCausalLM):
             num_layers=len(model.model.layers),
             num_kv_heads=model.model.num_key_value_heads,
             head_size=model.model.head_size,
-            dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
+            dtype=dtype,
             device=device,
             rank=rank,
             world_size=world_size,
+            kv_cache_dtype=torch.uint8 if "fp8" in kv_cache_dtype else dtype,
         )
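
A minimal sketch of the behaviour this patch wires up, for reviewers (not part of the diff; the `resolve_kv_cache_dtype` helper, the `BLOCK_SIZE` value, and the example shapes are illustrative assumptions, not TGI code): `fp8`/`fp8_e5m2` caches are stored as `torch.uint8`, and the free-memory block calculation in `flash_causal_lm.py` now sizes elements from `kv_cache_dtype` rather than the model `dtype`.

```python
import torch

BLOCK_SIZE = 16  # assumption for the sketch: paged-attention block size


def resolve_kv_cache_dtype(kv_cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    """Illustrative helper mirroring the patch: fp8 variants are stored as uint8,
    anything else falls back to the model's dtype (the previous behaviour)."""
    return torch.uint8 if "fp8" in kv_cache_dtype else model_dtype


# An fp8 KV cache halves the per-element size relative to float16.
kv_dtype = resolve_kv_cache_dtype("fp8", torch.float16)
dtype_size = torch.tensor([], dtype=kv_dtype).element_size()  # 1 byte for uint8

# Cache-block sizing as done in FlashCausalLM's warmup, with example shapes.
num_layers, num_kv_heads, head_size = 32, 8, 128
cache_block_size = BLOCK_SIZE * num_kv_heads * head_size
total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # bytes per block (K and V, all layers)
print(kv_dtype, total_cache_size)
```

With this in place, `FlashLlama` passes the model `dtype` through unchanged and supplies the storage dtype separately via the new `kv_cache_dtype` argument, so only the cache allocation and block-count math see `uint8` when an fp8 cache is requested.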