Fixing prefix caching for flashdecoding.

This commit is contained in:
Nicolas Patry 2024-08-27 14:23:51 +02:00
parent 7f1816a4e1
commit 65b94a69bd
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
4 changed files with 97 additions and 47 deletions

View File

@ -9,26 +9,48 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
max_k: int
def __init__(self, input_lengths):
def __init__(
self,
input_lengths,
prefix_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.prefix_lengths = prefix_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
if cu_seqlen_q is None:
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
max_q = 1
else:
assert max_q is not None
assert max_k is not None
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# if max_q is not None and max_q < 1000 and max_q > 1:
# import ipdb;ipdb.set_trace()
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
total = self.input_lengths + self.prefix_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k
self.max_q = max_q
self.max_k = max_k
def clamp(self, max):
# Flash decoding doesn't need to clamp
@ -39,6 +61,11 @@ else:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: torch.Tensor
max_q: int
max_k: int
def clamp(self, max):
raise NotImplementedError("Not implemented seqlen for paged")
return Seqlen(torch.clamp(self.input_lengths, max=max))

View File

@ -222,12 +222,10 @@ if ATTENTION == "flashinfer":
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
@ -250,12 +248,10 @@ elif V2:
def attention(
q,
k,
v,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens,
max_s,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
@ -266,17 +262,17 @@ elif V2:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
None,
max_s,
max_s,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,

View File

@ -32,6 +32,7 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
):
@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module):
# flash attention
attn_output = attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
kv_cache[0],
kv_cache[1],
cu_seqlen_prefill,
max_s,
seqlen,
block_tables,
self.softmax_scale,
)
# Decode
@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
seqlen,
max_s,
)
@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
):
@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
)
@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module):
kv_cache[i],
block_tables,
slots,
input_lengths,
seqlen,
max_s,
adapter_data,
)
@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
kv_cache,
block_tables,
slots,
input_lengths,
seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,

View File

@ -283,7 +283,6 @@ class FlashCausalLMBatch(Batch):
all_input_ids.append(tokenized_input)
# Position ids
print(f"Prefix {prefix_len} - Orig {orig_input_length}")
request_position_ids = torch.arange(
prefix_len, orig_input_length, dtype=torch.int32
)
@ -1158,8 +1157,15 @@ class FlashCausalLM(Model):
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"prefix_lengths": prefix_lengths_tensor,
}
input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
@ -1202,7 +1208,7 @@ class FlashCausalLM(Model):
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths_,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
@ -1211,7 +1217,13 @@ class FlashCausalLM(Model):
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
prefix_lengths=prefix_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@ -1219,7 +1231,7 @@ class FlashCausalLM(Model):
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths_tensor,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
@ -1358,18 +1370,26 @@ class FlashCausalLM(Model):
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
input_lengths = Seqlen(input_lengths=input_lengths)
prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
cu_seqlen_prefill = torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
)
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=1,
max_k=seqlen,
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
),
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=self.kv_cache,
block_tables=None,
input_lengths=input_lengths,
seqlen=seqlen,
slots=slots,
max_s=seqlen,
lm_head_indices=None,
@ -1449,7 +1469,8 @@ class FlashCausalLM(Model):
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = input_lengths + prefix_lens_tensor
# TODO
# input_lengths = input_lengths + prefix_lens_tensor
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
@ -1464,7 +1485,14 @@ class FlashCausalLM(Model):
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
):
input_lengths = Seqlen(input_lengths=input_lengths)
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@ -1472,7 +1500,7 @@ class FlashCausalLM(Model):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,