From fa14d71ac82dd02dc9325f1c8fea2103e1b6b2c2 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Wed, 18 Dec 2024 14:55:53 +0000 Subject: [PATCH] add fp8 kv cache for rocm --- .../layers/attention/kv_cache.py | 50 +++++++++-- .../layers/attention/rocm.py | 18 ++-- .../custom_modeling/flash_llama_modeling.py | 12 ++- .../custom_modeling/flash_mixtral_modeling.py | 86 ++++++++++++++----- 4 files changed, 122 insertions(+), 44 deletions(-) diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py index 93d74732..67105057 100644 --- a/server/text_generation_server/layers/attention/kv_cache.py +++ b/server/text_generation_server/layers/attention/kv_cache.py @@ -52,13 +52,22 @@ class KVCache: device: torch.device, ): """Construct the key-value cache for a layer.""" + if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}: + if (ATTENTION == "flashinfer" and SYSTEM == "cuda") or not ( + ATTENTION == "paged" and SYSTEM == "rocm" + ): + raise ValueError( + "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCM" + ) + if SYSTEM == "rocm" and dtype == torch.float8_e5m2: + raise ValueError( + "float8_e5m2 FP8 KV cache is not supported on AMD Rocm" + ) - if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and ( - ATTENTION != "flashinfer" or SYSTEM != "cuda" - ): - raise ValueError( - "FP8 KV cache is currently only supported for flashinfer on CUDA" - ) + self.kv_cache_dtype_str = "auto" + if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn: + self.kv_cache_dtype_str = "fp8" + dtype = torch.uint8 element_size = torch.tensor([], dtype=dtype).element_size() if SYSTEM == "ipex" and device.type == "xpu": @@ -120,6 +129,16 @@ class KVCache: "Using FP8 KV cache scales", ) return True + elif ( + self.kv_cache_dtype_str == "fp8" + and ATTENTION == "paged" + and SYSTEM == "rocm" + ): + log_once( + logger.info, + "Using FP8 KV cache scales", + ) + return True else: # We have scales, but not the correct FP8 cache type, so warn once. log_once( @@ -158,7 +177,7 @@ class KVCache: key_cache = self.kv_cache[0] value_cache = self.kv_cache[1] - if self.can_scale(kv_scales): + if self.can_scale(kv_scales) and SYSTEM == "cuda": if kv_scales.key_scale_cpu != 1.0: key = fp8_quantize( key.float(), @@ -188,7 +207,16 @@ class KVCache: key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value else: - paged_reshape_and_cache(key, value, key_cache, value_cache, slots) + paged_reshape_and_cache( + key, + value, + key_cache, + value_cache, + slots, + self.kv_cache_dtype_str, + kv_scales.key_scale_cpu, + kv_scales.value_scale_cpu, + ) def paged_reshape_and_cache( @@ -197,7 +225,11 @@ def paged_reshape_and_cache( key_cache: torch.Tensor, value_cache: torch.Tensor, slots: torch.Tensor, + kv_cache_dtype: str = "auto", + k_scale: float = 1.0, + v_scale: float = 1.0, ): + if SYSTEM == "cuda": try: import attention_kernels @@ -216,7 +248,7 @@ def paged_reshape_and_cache( f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}" ) ops.reshape_and_cache( - key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0 + key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale ) elif SYSTEM == "ipex": import intel_extension_for_pytorch as ipex diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 0cfac25b..bc790f06 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -119,9 +119,9 @@ def paged_attention( block_size, max_s, None, - "auto", - 1.0, - 1.0, + kv_cache.kv_cache_dtype_str, + kv_scales.key_scale_cpu, + kv_scales.value_scale_cpu, ) else: # Run PagedAttention V2. @@ -154,9 +154,9 @@ def paged_attention( block_size, max_s, None, - "auto", - 1.0, - 1.0, + kv_cache.kv_cache_dtype_str, + kv_scales.key_scale_cpu, + kv_scales.value_scale_cpu, ) else: ops.paged_attention_rocm( @@ -174,9 +174,9 @@ def paged_attention( block_size, max_s, None, - "auto", - 1.0, - 1.0, + kv_cache.kv_cache_dtype_str, + kv_scales.key_scale_cpu, + kv_scales.value_scale_cpu, None, _PARTITION_SIZE, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 10309006..53df59df 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -398,10 +398,16 @@ class LlamaMLP(nn.Module): return self.down_proj(out, adapter_data) else: gate_up_states = self.gate_up_proj(hidden_states, adapter_data) - gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) - return self.down_proj( - self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,) + out = torch.empty( + output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device ) + ops.silu_and_mul(out, gate_up_states) + return self.down_proj(out, adapter_data) + # gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + # return self.down_proj( + # self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data + # ) class FlashLlamaLayer(nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index a45dd1e6..1bc6c7d4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -520,28 +520,68 @@ class FlashMixtralForCausalLM(torch.nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - elif self.max_past is not None: - # Clamp in decode mode as paged attention requires clamped values whereas the flash attention - # kernel requires the true values - seqlen = seqlen.clamp(max=self.max_past_tensor) - hidden_states = self.model( - input_ids, - position_ids, - cu_seqlen_prefill, - kv_cache, - block_tables, - slots, - seqlen, - max_s, - true_max_s, - prefill_cache_indices, - ) - if lm_head_indices is not None: - hidden_states = hidden_states[lm_head_indices] - logits = self.lm_head(hidden_states) + if ( + torch.distributed.get_rank() == 0 + and input_ids.shape[0] == 262144 + and cu_seqlen_prefill is not None + ): + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + ) as prof: + true_max_s = max_s + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif self.max_past is not None: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + seqlen = seqlen.clamp(max=self.max_past_tensor) + + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + true_max_s, + prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + + prof.export_chrome_trace("/tgi/trace_mistral_prefill.json") + else: + true_max_s = max_s + if prefill_cache_indices is not None: + # Slots also need to be sliced as it has the same size as the whole kv tensor + slots = slots[prefill_cache_indices] + elif self.max_past is not None: + # Clamp in decode mode as paged attention requires clamped values whereas the flash attention + # kernel requires the true values + seqlen = seqlen.clamp(max=self.max_past_tensor) + + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + true_max_s, + prefill_cache_indices, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) return logits