From 65b94a69bddb35257fc3d3b4c6e64d7f7a7d61d7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 27 Aug 2024 14:23:51 +0200 Subject: [PATCH] Fixing prefix caching for flashdecoding. --- .../layers/attention/common.py | 41 +++++++++++--- .../layers/attention/cuda.py | 26 ++++----- .../custom_modeling/flash_llama_modeling.py | 23 ++++---- .../models/flash_causal_lm.py | 54 ++++++++++++++----- 4 files changed, 97 insertions(+), 47 deletions(-) diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index f162230c..0faec0ea 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -9,26 +9,48 @@ if ATTENTION in {"flashinfer", "flashdecoding"}: @dataclass class Seqlen: input_lengths: torch.Tensor + prefix_lengths: torch.Tensor cu_seqlen_q: Optional[torch.Tensor] cu_seqlen_k: Optional[torch.Tensor] + max_q: int + max_k: int - def __init__(self, input_lengths): + def __init__( + self, + input_lengths, + prefix_lengths, + cu_seqlen_q=None, + max_q=None, + max_k=None, + ): self.input_lengths = input_lengths + self.prefix_lengths = prefix_lengths device = self.input_lengths.device shape = self.input_lengths.shape - cu_seqlen_q = torch.arange( - shape[0] + 1, - device=device, - dtype=torch.int32, - ) + if cu_seqlen_q is None: + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + max_q = 1 + else: + assert max_q is not None + assert max_k is not None cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + # if max_q is not None and max_q < 1000 and max_q > 1: + # import ipdb;ipdb.set_trace() + # cuda graphs don't like this and this is necessary to clamp within mistral # Although FA2 might not want the clamping # cu_seqlen_k[0] = 0 - torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:]) + total = self.input_lengths + self.prefix_lengths + torch.cumsum(total, -1, out=cu_seqlen_k[1:]) self.cu_seqlen_q = cu_seqlen_q self.cu_seqlen_k = cu_seqlen_k + self.max_q = max_q + self.max_k = max_k def clamp(self, max): # Flash decoding doesn't need to clamp @@ -39,6 +61,11 @@ else: @dataclass class Seqlen: input_lengths: torch.Tensor + prefix_lengths: torch.Tensor + cu_seqlen_q: torch.Tensor + max_q: int + max_k: int def clamp(self, max): + raise NotImplementedError("Not implemented seqlen for paged") return Seqlen(torch.clamp(self.input_lengths, max=max)) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 40d71e2d..4b588b5c 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -222,12 +222,10 @@ if ATTENTION == "flashinfer": def attention( q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - cu_seqlens, - max_s, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, window_size_left=-1, causal=True, @@ -250,12 +248,10 @@ elif V2: def attention( q, - k, - v, key_cache: torch.Tensor, value_cache: torch.Tensor, - cu_seqlens, - max_s, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, window_size_left=-1, causal=True, @@ -266,17 +262,17 @@ elif V2: raise ValueError("`window_size_left` must be > 0 or -1") return flash_attn_2_cuda.varlen_fwd( q, - k, - v, + key_cache, + value_cache, out, - cu_seqlens, - cu_seqlens, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, None, None, + 
block_tables, None, - None, - max_s, - max_s, + seqlen.max_q, + seqlen.max_k, 0.0, softmax_scale, False, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 3253d2dc..5b228f9f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -32,6 +32,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module): # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module): self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module): kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a2c218eb..265255dd 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -283,7 +283,6 @@ class FlashCausalLMBatch(Batch): all_input_ids.append(tokenized_input) # Position ids - print(f"Prefix {prefix_len} - Orig {orig_input_length}") request_position_ids = torch.arange( prefix_len, orig_input_length, dtype=torch.int32 ) @@ -1158,8 +1157,15 @@ class FlashCausalLM(Model): "block_tables": block_tables, "slots": slots, "input_lengths": input_lengths_tensor, + "prefix_lengths": prefix_lengths_tensor, } - input_lengths_ = Seqlen(input_lengths=input_lengths_tensor) + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + prefix_lengths=prefix_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph @@ -1202,7 +1208,7 @@ class FlashCausalLM(Model): 
kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths_, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -1211,7 +1217,13 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor) + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + prefix_lengths=prefix_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1219,7 +1231,7 @@ class FlashCausalLM(Model): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths_tensor, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -1358,18 +1370,26 @@ class FlashCausalLM(Model): # Dummy value, some models (starcoder2) don't accept `None`. input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - input_lengths = Seqlen(input_lengths=input_lengths) + prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device) + cu_seqlen_prefill = torch.tensor( + [0, seqlen], device=self.device, dtype=torch.int32 + ) + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=1, + max_k=seqlen, + ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( input_ids=input_ids, position_ids=position_ids, - cu_seqlen_prefill=torch.tensor( - [0, seqlen], device=self.device, dtype=torch.int32 - ), + cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=self.kv_cache, block_tables=None, - input_lengths=input_lengths, + seqlen=seqlen, slots=slots, max_s=seqlen, lm_head_indices=None, @@ -1449,7 +1469,8 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = input_lengths + prefix_lens_tensor + # TODO + # input_lengths = input_lengths + prefix_lens_tensor if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, @@ -1464,7 +1485,14 @@ class FlashCausalLM(Model): prefix_lens=batch.prefix_lens, prefix_lens_tensor=prefix_lens_tensor, ): - input_lengths = Seqlen(input_lengths=input_lengths) + max_k = (input_lengths + prefix_lens_tensor).max().item() + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=max_s, + max_k=max_k, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1472,7 +1500,7 @@ class FlashCausalLM(Model): kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices,
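
Note on the Seqlen change: the heart of the fix is that Seqlen now carries the cached prefix length of every request and folds it into the key offsets handed to the flash-decoding kernel. cu_seqlen_k becomes the cumulative sum of input_lengths + prefix_lengths, while cu_seqlen_q still defaults to one query token per request when no prefill cu_seqlens are supplied. A minimal standalone sketch of that construction (the tensor values are illustrative, not taken from the patch):

    import torch

    # Per-request lengths for a batch of three sequences (illustrative).
    # input_lengths counts the tokens each request owns outside the cached
    # prefix; prefix_lengths counts the tokens already sitting in the KV cache.
    input_lengths = torch.tensor([5, 3, 7], dtype=torch.int32)
    prefix_lengths = torch.tensor([2, 0, 4], dtype=torch.int32)

    # Query offsets: with no prefill cu_seqlens, each request contributes a
    # single query token, so cu_seqlen_q is simply 0..batch_size.
    cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)

    # Key offsets: the kernel must attend over prefix + own tokens, so the
    # cumulative sum runs over the total length, not input_lengths alone.
    cu_seqlen_k = torch.zeros(input_lengths.shape[0] + 1, dtype=torch.int32)
    torch.cumsum(input_lengths + prefix_lengths, -1, out=cu_seqlen_k[1:])

    print(cu_seqlen_q.tolist())  # [0, 1, 2, 3]
    print(cu_seqlen_k.tolist())  # [0, 7, 10, 21]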
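
Note on the two Seqlen construction sites in flash_causal_lm.py: the graph-captured decode path and the prefill path pick the new max_q/max_k fields differently, and the prefill path must count the cached prefix when sizing max_k. The helper below condenses that selection; the name build_seqlen_kwargs is mine and the branch condition is reduced to the cu_seqlen_prefill check, so treat it as a sketch rather than the exact control flow of the patch:

    import torch
    from typing import Optional


    def build_seqlen_kwargs(
        input_lengths: torch.Tensor,
        prefix_lengths: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        max_s: int,
    ) -> dict:
        if cu_seqlen_prefill is None:
            # Graph-captured decode: one query token per request, so the
            # Seqlen constructor derives cu_seqlen_q itself and max_q is 1;
            # keys are bounded by the padded sequence length.
            return dict(cu_seqlen_q=None, max_q=1, max_k=max_s)
        # Prefill: queries span whole chunks, and the keys must also cover
        # the tokens already materialised in the prefix cache.
        max_k = int((input_lengths + prefix_lengths).max().item())
        return dict(cu_seqlen_q=cu_seqlen_prefill, max_q=max_s, max_k=max_k)

In the patch these values are passed explicitly to Seqlen(input_lengths=..., prefix_lengths=..., cu_seqlen_q=..., max_q=..., max_k=...) right before model.forward.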
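
Note on the CUDA graph hunk: prefix_lengths_tensor is stored next to the other static buffers and the Seqlen is rebuilt inside the capture because a CUDA graph records the addresses of the tensors it sees at capture time; replays read those exact buffers, so later batches may only copy_ new values into them, never rebind the attributes. A generic sketch of that constraint, assuming a CUDA device (the tensors and the decode_step stand-in are illustrative, not the batch state from the patch):

    import torch

    assert torch.cuda.is_available()  # the sketch needs a CUDA device

    # Static buffers: the graph will record the addresses of these tensors.
    static_input_lengths = torch.ones(4, dtype=torch.int32, device="cuda")
    static_prefix_lengths = torch.zeros(4, dtype=torch.int32, device="cuda")


    def decode_step(input_lengths, prefix_lengths):
        # Stand-in for the real forward pass; only the dependency on both
        # static buffers matters here.
        return (input_lengths + prefix_lengths).sum()


    # Warm up once, then capture against a shared pool (the patch uses MEM_POOL).
    decode_step(static_input_lengths, static_prefix_lengths)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=torch.cuda.graph_pool_handle()):
        static_out = decode_step(static_input_lengths, static_prefix_lengths)

    # Replay with fresh data: update the contents of the captured buffers.
    static_input_lengths.copy_(
        torch.tensor([3, 5, 2, 7], dtype=torch.int32, device="cuda")
    )
    static_prefix_lengths.copy_(
        torch.tensor([1, 0, 4, 2], dtype=torch.int32, device="cuda")
    )
    graph.replay()
    print(static_out.item())  # 24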