Fixing prefix caching for flashdecoding.

Nicolas Patry 2024-08-27 14:23:51 +02:00
parent 7f1816a4e1
commit 65b94a69bd
GPG Key ID: 64AF4752B2967863
4 changed files with 97 additions and 47 deletions

File 1 of 4

@@ -9,26 +9,48 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
     @dataclass
     class Seqlen:
         input_lengths: torch.Tensor
+        prefix_lengths: torch.Tensor
         cu_seqlen_q: Optional[torch.Tensor]
         cu_seqlen_k: Optional[torch.Tensor]
+        max_q: int
+        max_k: int
 
-        def __init__(self, input_lengths):
+        def __init__(
+            self,
+            input_lengths,
+            prefix_lengths,
+            cu_seqlen_q=None,
+            max_q=None,
+            max_k=None,
+        ):
             self.input_lengths = input_lengths
+            self.prefix_lengths = prefix_lengths
             device = self.input_lengths.device
             shape = self.input_lengths.shape
-            cu_seqlen_q = torch.arange(
-                shape[0] + 1,
-                device=device,
-                dtype=torch.int32,
-            )
+            if cu_seqlen_q is None:
+                cu_seqlen_q = torch.arange(
+                    shape[0] + 1,
+                    device=device,
+                    dtype=torch.int32,
+                )
+                max_q = 1
+            else:
+                assert max_q is not None
+                assert max_k is not None
             cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # if max_q is not None and max_q < 1000 and max_q > 1:
+            # import ipdb;ipdb.set_trace()
             # cuda graphs don't like this and this is necessary to clamp within mistral
             # Although FA2 might not want the clamping
             # cu_seqlen_k[0] = 0
-            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
+            total = self.input_lengths + self.prefix_lengths
+            torch.cumsum(total, -1, out=cu_seqlen_k[1:])
 
             self.cu_seqlen_q = cu_seqlen_q
             self.cu_seqlen_k = cu_seqlen_k
+            self.max_q = max_q
+            self.max_k = max_k
 
         def clamp(self, max):
             # Flash decoding doesn't need to clamp
@@ -39,6 +61,11 @@ else:
     @dataclass
     class Seqlen:
         input_lengths: torch.Tensor
+        prefix_lengths: torch.Tensor
+        cu_seqlen_q: torch.Tensor
+        max_q: int
+        max_k: int
 
         def clamp(self, max):
+            raise NotImplementedError("Not implemented seqlen for paged")
             return Seqlen(torch.clamp(self.input_lengths, max=max))
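For reference, a minimal runnable sketch (not part of the diff) of what the updated flashdecoding Seqlen computes: cu_seqlen_k is now the cumulative sum of input_lengths + prefix_lengths, so tokens already held in the prefix cache count toward the key length, while the decode-path cu_seqlen_q (built when cu_seqlen_q=None) stays at one query token per sequence. The batch values below are made up for illustration.

import torch

# Two sequences in a decode batch: 3 and 5 tokens in the current input,
# with 2 and 0 tokens already sitting in the prefix (KV) cache.
input_lengths = torch.tensor([3, 5], dtype=torch.int32)
prefix_lengths = torch.tensor([2, 0], dtype=torch.int32)

total = input_lengths + prefix_lengths                        # tensor([5, 5])
cu_seqlen_k = torch.zeros(input_lengths.shape[-1] + 1, dtype=torch.int32)
torch.cumsum(total, -1, out=cu_seqlen_k[1:])                  # tensor([0, 5, 10])

# Decode default (cu_seqlen_q=None in the new constructor): one query per sequence.
cu_seqlen_q = torch.arange(input_lengths.shape[0] + 1, dtype=torch.int32)
print(cu_seqlen_q.tolist(), cu_seqlen_k.tolist())             # [0, 1, 2] [0, 5, 10]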

File 2 of 4

@@ -222,12 +222,10 @@ if ATTENTION == "flashinfer":
 
     def attention(
         q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
-        cu_seqlens,
-        max_s,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
@@ -250,12 +248,10 @@ elif V2:
 
     def attention(
         q,
-        k,
-        v,
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
-        cu_seqlens,
-        max_s,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
@@ -266,17 +262,17 @@ elif V2:
             raise ValueError("`window_size_left` must be > 0 or -1")
         return flash_attn_2_cuda.varlen_fwd(
             q,
-            k,
-            v,
+            key_cache,
+            value_cache,
             out,
-            cu_seqlens,
-            cu_seqlens,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
             None,
             None,
+            block_tables,
             None,
-            None,
-            max_s,
-            max_s,
+            seqlen.max_q,
+            seqlen.max_k,
             0.0,
             softmax_scale,
             False,
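A hedged sketch of a prefill call site after this change, assuming the flashdecoding backend of text-generation-inference on a CUDA device with flash-attn v2 installed: k/v are no longer passed, the kernel reads them from the paged key_cache/value_cache through block_tables, and the cu_seqlen/max values travel inside Seqlen. The shapes, head counts, and block size below are illustrative assumptions, not values from the commit.

import torch
from text_generation_server.layers.attention import attention, Seqlen

device = torch.device("cuda")

# One 16-token prefill request, no cached prefix; 32 query heads, 8 KV heads,
# head_size 128, block_size 16 (all assumed for illustration).
q = torch.randn(16, 32, 128, dtype=torch.float16, device=device)
key_cache = torch.randn(64, 16, 8, 128, dtype=torch.float16, device=device)
value_cache = torch.randn(64, 16, 8, 128, dtype=torch.float16, device=device)
block_tables = torch.zeros(1, 1, dtype=torch.int32, device=device)  # block 0 only

cu_seqlen_prefill = torch.tensor([0, 16], dtype=torch.int32, device=device)
seqlen = Seqlen(
    input_lengths=torch.tensor([16], dtype=torch.int32, device=device),
    prefix_lengths=torch.tensor([0], dtype=torch.int32, device=device),
    cu_seqlen_q=cu_seqlen_prefill,
    max_q=16,
    max_k=16,
)

# The kernel gathers K/V from the cache via block_tables; no separate k/v tensors.
attn_output = attention(q, key_cache, value_cache, seqlen, block_tables, 128**-0.5)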

File 3 of 4

@@ -32,6 +32,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         adapter_data,
     ):
@@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )
@@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         adapter_data,
     ):
@@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             adapter_data,
         )
@@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 adapter_data,
             )
@@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
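The modeling change is mechanical: every layer now threads one Seqlen object instead of a raw input_lengths tensor, and the prefill branch reads K/V from the paged cache. The stand-in stub below (not TGI code) summarises which arguments each branch receives after this commit; the leading paged_attention arguments sit outside the hunk and are assumed from context.

def run_attention(query, kv_cache, block_tables, seqlen, max_s,
                  cu_seqlen_prefill, softmax_scale, kv_head_mapping,
                  attention, paged_attention):
    # Prefill: flash attention over the paged KV cache, lengths carried by `seqlen`.
    if cu_seqlen_prefill is not None:
        return attention(
            query, kv_cache[0], kv_cache[1], seqlen, block_tables, softmax_scale
        )
    # Decode: paged attention also takes the Seqlen object now, not input_lengths.
    return paged_attention(
        query, kv_cache[0], kv_cache[1], kv_head_mapping,
        softmax_scale, block_tables, seqlen, max_s
    )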

File 4 of 4

@@ -283,7 +283,6 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(tokenized_input)
 
             # Position ids
-            print(f"Prefix {prefix_len} - Orig {orig_input_length}")
             request_position_ids = torch.arange(
                 prefix_len, orig_input_length, dtype=torch.int32
             )
@@ -1158,8 +1157,15 @@ class FlashCausalLM(Model):
             "block_tables": block_tables,
             "slots": slots,
             "input_lengths": input_lengths_tensor,
+            "prefix_lengths": prefix_lengths_tensor,
         }
-        input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
+        seqlen = Seqlen(
+            input_lengths=input_lengths_tensor,
+            prefix_lengths=prefix_lengths_tensor,
+            cu_seqlen_q=None,
+            max_q=1,
+            max_k=max_s,
+        )
 
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
@@ -1202,7 +1208,7 @@ class FlashCausalLM(Model):
             kv_cache=self.kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths_,
+            seqlen=seqlen,
            max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
@@ -1211,7 +1217,13 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
 
         with torch.cuda.graph(graph, pool=MEM_POOL):
-            input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
+            seqlen = Seqlen(
+                input_lengths=input_lengths_tensor,
+                prefix_lengths=prefix_lengths_tensor,
+                cu_seqlen_q=None,
+                max_q=1,
+                max_k=max_s,
+            )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -1219,7 +1231,7 @@ class FlashCausalLM(Model):
                 kv_cache=self.kv_cache,
                 block_tables=block_tables,
                 slots=slots,
-                input_lengths=input_lengths_tensor,
+                seqlen=seqlen,
                 max_s=max_s,
                 prefill_cache_indices=None,
                 lm_head_indices=None,
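The extra prefix_lengths buffer in the captured kwargs and the Seqlen rebuilt inside the capture block follow the usual CUDA-graph rule: everything the graph touches must live in preallocated buffers, and a replay only sees new data that has been copied into those same buffers. A generic, hedged illustration of that pattern (needs a GPU; not TGI code):

import torch

static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Captured work may only reference the static buffers above.
    static_out.copy_(static_in * 2)

# Feed new data by copying into the captured input buffer, then replay.
static_in.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
graph.replay()
print(static_out)  # tensor([ 0.,  2.,  4., ..., 14.], device='cuda:0')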
@@ -1358,18 +1370,26 @@ class FlashCausalLM(Model):
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        input_lengths = Seqlen(input_lengths=input_lengths)
+        prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
+        cu_seqlen_prefill = torch.tensor(
+            [0, seqlen], device=self.device, dtype=torch.int32
+        )
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            prefix_lengths=prefix_lens_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+            max_q=1,
+            max_k=seqlen,
+        )
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlen_prefill=torch.tensor(
-                [0, seqlen], device=self.device, dtype=torch.int32
-            ),
+            cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=self.kv_cache,
             block_tables=None,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
@@ -1449,7 +1469,8 @@ class FlashCausalLM(Model):
            cuda_graph = None
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = input_lengths + prefix_lens_tensor
+            # TODO
+            # input_lengths = input_lengths + prefix_lens_tensor
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
@@ -1464,7 +1485,14 @@ class FlashCausalLM(Model):
                 prefix_lens=batch.prefix_lens,
                 prefix_lens_tensor=prefix_lens_tensor,
             ):
-                input_lengths = Seqlen(input_lengths=input_lengths)
+                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    prefix_lengths=prefix_lens_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=max_s,
+                    max_k=max_k,
+                )
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
                     position_ids=position_ids,
@@ -1472,7 +1500,7 @@ class FlashCausalLM(Model):
                     kv_cache=kv_cache,
                     block_tables=block_tables,
                     slots=slots,
-                    input_lengths=input_lengths,
+                    seqlen=seqlen,
                     max_s=max_s,
                     prefill_cache_indices=batch.prefill_cache_indices,
                     lm_head_indices=lm_head_indices,
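For the prefill path, the key length now has to include the cached prefix tokens while the query length does not, which is why max_k is computed from input_lengths + prefix_lens_tensor while max_q stays at max_s. A small runnable check of that arithmetic with made-up batch values (plain tensors, no TGI imports):

import torch

# Three prefill requests: 7, 4 and 9 new tokens, with 0, 12 and 3 tokens
# already present in the prefix cache.
input_lengths = torch.tensor([7, 4, 9], dtype=torch.int32)
prefix_lens_tensor = torch.tensor([0, 12, 3], dtype=torch.int32)

max_s = input_lengths.max().item()                         # longest query span: 9
max_k = (input_lengths + prefix_lens_tensor).max().item()  # longest key span: 16
print(max_s, max_k)                                        # 9 16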