Fixing prefix caching for flashdecoding.
parent 7f1816a4e1
commit 65b94a69bd
@@ -9,26 +9,48 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
     @dataclass
     class Seqlen:
         input_lengths: torch.Tensor
+        prefix_lengths: torch.Tensor
         cu_seqlen_q: Optional[torch.Tensor]
         cu_seqlen_k: Optional[torch.Tensor]
+        max_q: int
+        max_k: int

-        def __init__(self, input_lengths):
+        def __init__(
+            self,
+            input_lengths,
+            prefix_lengths,
+            cu_seqlen_q=None,
+            max_q=None,
+            max_k=None,
+        ):
             self.input_lengths = input_lengths
+            self.prefix_lengths = prefix_lengths
             device = self.input_lengths.device
             shape = self.input_lengths.shape
-            cu_seqlen_q = torch.arange(
-                shape[0] + 1,
-                device=device,
-                dtype=torch.int32,
-            )
+            if cu_seqlen_q is None:
+                cu_seqlen_q = torch.arange(
+                    shape[0] + 1,
+                    device=device,
+                    dtype=torch.int32,
+                )
+                max_q = 1
+            else:
+                assert max_q is not None
+                assert max_k is not None
             cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # if max_q is not None and max_q < 1000 and max_q > 1:
+            #     import ipdb;ipdb.set_trace()
+
             # cuda graphs don't like this and this is necessary to clamp within mistral
             # Although FA2 might not want the clamping
             # cu_seqlen_k[0] = 0
-            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
+            total = self.input_lengths + self.prefix_lengths
+            torch.cumsum(total, -1, out=cu_seqlen_k[1:])

             self.cu_seqlen_q = cu_seqlen_q
             self.cu_seqlen_k = cu_seqlen_k
+            self.max_q = max_q
+            self.max_k = max_k

         def clamp(self, max):
             # Flash decoding doesn't need to clamp
@@ -39,6 +61,11 @@ else:
     @dataclass
     class Seqlen:
         input_lengths: torch.Tensor
+        prefix_lengths: torch.Tensor
+        cu_seqlen_q: torch.Tensor
+        max_q: int
+        max_k: int

         def clamp(self, max):
+            raise NotImplementedError("Not implemented seqlen for paged")
             return Seqlen(torch.clamp(self.input_lengths, max=max))
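To make the Seqlen change concrete: with prefix caching, the key-side cumulative lengths must now count the cached prefix tokens as well as the new tokens. A tiny self-contained sketch with made-up lengths (not part of the patch):

    import torch

    # Two requests: 3 and 5 non-cached tokens, with 4 and 0 tokens already in the prefix cache.
    input_lengths = torch.tensor([3, 5], dtype=torch.int32)
    prefix_lengths = torch.tensor([4, 0], dtype=torch.int32)

    # Same pattern as the patched __init__: keys span prefix + new tokens.
    total = input_lengths + prefix_lengths
    cu_seqlen_k = torch.zeros(input_lengths.shape[-1] + 1, dtype=torch.int32)
    torch.cumsum(total, -1, out=cu_seqlen_k[1:])

    print(cu_seqlen_k.tolist())  # [0, 7, 12] -- before this fix it was effectively [0, 3, 8]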
@@ -222,12 +222,10 @@ if ATTENTION == "flashinfer":

    def attention(
        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
-        cu_seqlens,
-        max_s,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
        softmax_scale,
        window_size_left=-1,
        causal=True,
@@ -250,12 +248,10 @@ elif V2:

    def attention(
        q,
-        k,
-        v,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
-        cu_seqlens,
-        max_s,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
        softmax_scale,
        window_size_left=-1,
        causal=True,
@@ -266,17 +262,17 @@ elif V2:
            raise ValueError("`window_size_left` must be > 0 or -1")
        return flash_attn_2_cuda.varlen_fwd(
            q,
-            k,
-            v,
+            key_cache,
+            value_cache,
            out,
-            cu_seqlens,
-            cu_seqlens,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
            None,
            None,
+            block_tables,
            None,
-            None,
-            max_s,
-            max_s,
+            seqlen.max_q,
+            seqlen.max_k,
            0.0,
            softmax_scale,
            False,
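The attention wrapper now receives the paged key/value cache, the block tables, and a Seqlen instead of raw k/v plus a single cu_seqlens/max_s. The sketch below is a hypothetical pure-Python stand-in for that wrapper (it does not call the real flash_attn_2_cuda kernel); it only illustrates which metadata travels where, and all shapes and values are made up:

    import torch
    from types import SimpleNamespace

    def attention(q, key_cache, value_cache, seqlen, block_tables, softmax_scale,
                  window_size_left=-1, causal=True):
        # A real implementation forwards to flash_attn_2_cuda.varlen_fwd, passing
        # key_cache/value_cache in place of k/v, seqlen.cu_seqlen_q and
        # seqlen.cu_seqlen_k in place of the single cu_seqlens, block_tables for
        # the paged cache, and seqlen.max_q/seqlen.max_k in place of max_s.
        assert seqlen.cu_seqlen_q is not None and seqlen.cu_seqlen_k is not None
        return torch.empty_like(q)

    # Stand-in Seqlen for two requests: 3 and 5 query tokens, 7 and 5 key tokens.
    seqlen = SimpleNamespace(
        cu_seqlen_q=torch.tensor([0, 3, 8], dtype=torch.int32),
        cu_seqlen_k=torch.tensor([0, 7, 12], dtype=torch.int32),
        max_q=5,
        max_k=7,
    )
    q = torch.randn(8, 2, 64)                # [total_q_tokens, heads, head_dim]
    key_cache = torch.randn(4, 16, 2, 64)    # tiny made-up paged cache
    value_cache = torch.randn(4, 16, 2, 64)
    block_tables = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32)
    out = attention(q, key_cache, value_cache, seqlen, block_tables, softmax_scale=0.125)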
@@ -32,6 +32,7 @@ from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    reshape_and_cache,
+    Seqlen,
)
from text_generation_server.layers import (
    TensorParallelRowLinear,
@@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module):
        kv_cache,
        block_tables,
        slots,
-        input_lengths,
+        seqlen,
        max_s,
        adapter_data,
    ):
@@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                kv_cache[0],
                kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                self.softmax_scale,
            )
        # Decode
@@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module):
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
-                input_lengths,
+                seqlen,
                max_s,
            )

@@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module):
        kv_cache,
        block_tables,
        slots,
-        input_lengths,
+        seqlen,
        max_s,
        adapter_data,
    ):
@@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module):
            kv_cache,
            block_tables,
            slots,
-            input_lengths,
+            seqlen,
            max_s,
            adapter_data,
        )
@@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module):
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
        max_s: int,
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
@@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module):
                kv_cache[i],
                block_tables,
                slots,
-                input_lengths,
+                seqlen,
                max_s,
                adapter_data,
            )
@@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
@@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
            kv_cache,
            block_tables,
            slots,
-            input_lengths,
+            seqlen,
            max_s,
            true_max_s=max_s,
            prefill_cache_indices=prefill_cache_indices,
@@ -283,7 +283,6 @@ class FlashCausalLMBatch(Batch):
            all_input_ids.append(tokenized_input)

            # Position ids
-            print(f"Prefix {prefix_len} - Orig {orig_input_length}")
            request_position_ids = torch.arange(
                prefix_len, orig_input_length, dtype=torch.int32
            )
@@ -1158,8 +1157,15 @@ class FlashCausalLM(Model):
            "block_tables": block_tables,
            "slots": slots,
            "input_lengths": input_lengths_tensor,
+            "prefix_lengths": prefix_lengths_tensor,
        }
-        input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
+        seqlen = Seqlen(
+            input_lengths=input_lengths_tensor,
+            prefix_lengths=prefix_lengths_tensor,
+            cu_seqlen_q=None,
+            max_q=1,
+            max_k=max_s,
+        )
        graph = torch.cuda.CUDAGraph()
        self.cuda_graphs[bs]["graph"] = graph

@@ -1202,7 +1208,7 @@ class FlashCausalLM(Model):
            kv_cache=self.kv_cache,
            block_tables=block_tables,
            slots=slots,
-            input_lengths=input_lengths_,
+            seqlen=seqlen,
            max_s=max_s,
            prefill_cache_indices=None,
            lm_head_indices=None,
@@ -1211,7 +1217,13 @@ class FlashCausalLM(Model):
        torch.cuda.synchronize()

        with torch.cuda.graph(graph, pool=MEM_POOL):
-            input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
+            seqlen = Seqlen(
+                input_lengths=input_lengths_tensor,
+                prefix_lengths=prefix_lengths_tensor,
+                cu_seqlen_q=None,
+                max_q=1,
+                max_k=max_s,
+            )
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
                position_ids=position_ids,
@@ -1219,7 +1231,7 @@ class FlashCausalLM(Model):
                kv_cache=self.kv_cache,
                block_tables=block_tables,
                slots=slots,
-                input_lengths=input_lengths_tensor,
+                seqlen=seqlen,
                max_s=max_s,
                prefill_cache_indices=None,
                lm_head_indices=None,
@@ -1358,18 +1370,26 @@ class FlashCausalLM(Model):

        # Dummy value, some models (starcoder2) don't accept `None`.
        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        input_lengths = Seqlen(input_lengths=input_lengths)
+        prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
+        cu_seqlen_prefill = torch.tensor(
+            [0, seqlen], device=self.device, dtype=torch.int32
+        )
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            prefix_lengths=prefix_lens_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+            max_q=1,
+            max_k=seqlen,
+        )

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
-            cu_seqlen_prefill=torch.tensor(
-                [0, seqlen], device=self.device, dtype=torch.int32
-            ),
+            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=self.kv_cache,
            block_tables=None,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
            slots=slots,
            max_s=seqlen,
            lm_head_indices=None,
@@ -1449,7 +1469,8 @@ class FlashCausalLM(Model):
            cuda_graph = None

        if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = input_lengths + prefix_lens_tensor
+            # TODO
+            # input_lengths = input_lengths + prefix_lens_tensor
            if ATTENTION == "flashinfer":
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
@@ -1464,7 +1485,14 @@ class FlashCausalLM(Model):
                prefix_lens=batch.prefix_lens,
                prefix_lens_tensor=prefix_lens_tensor,
            ):
-                input_lengths = Seqlen(input_lengths=input_lengths)
+                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    prefix_lengths=prefix_lens_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=max_s,
+                    max_k=max_k,
+                )
                logits, speculative_logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
@@ -1472,7 +1500,7 @@ class FlashCausalLM(Model):
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
-                    input_lengths=input_lengths,
+                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
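Putting the forward-path hunks together, here is a condensed, runnable sketch of the prefill wiring. The Seqlen body is trimmed to the logic added in this commit, and the lengths are made up:

    import torch

    class Seqlen:
        # Trimmed from the first hunk of this commit.
        def __init__(self, input_lengths, prefix_lengths, cu_seqlen_q=None, max_q=None, max_k=None):
            self.input_lengths = input_lengths
            self.prefix_lengths = prefix_lengths
            device = input_lengths.device
            shape = input_lengths.shape
            if cu_seqlen_q is None:  # decode: one query token per request
                cu_seqlen_q = torch.arange(shape[0] + 1, device=device, dtype=torch.int32)
                max_q = 1
            else:  # prefill: the caller supplies cu_seqlen_prefill and the maxima
                assert max_q is not None and max_k is not None
            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
            torch.cumsum(input_lengths + prefix_lengths, -1, out=cu_seqlen_k[1:])
            self.cu_seqlen_q, self.cu_seqlen_k = cu_seqlen_q, cu_seqlen_k
            self.max_q, self.max_k = max_q, max_k

    # Prefill for 2 requests: 3 and 5 new tokens, 4 and 0 tokens already cached.
    input_lengths = torch.tensor([3, 5], dtype=torch.int32)
    prefix_lens_tensor = torch.tensor([4, 0], dtype=torch.int32)
    cu_seqlen_prefill = torch.tensor([0, 3, 8], dtype=torch.int32)
    max_s = 5
    max_k = (input_lengths + prefix_lens_tensor).max().item()
    seqlen = Seqlen(
        input_lengths=input_lengths,
        prefix_lengths=prefix_lens_tensor,
        cu_seqlen_q=cu_seqlen_prefill,
        max_q=max_s,
        max_k=max_k,
    )
    print(seqlen.cu_seqlen_q.tolist(), seqlen.cu_seqlen_k.tolist())  # [0, 3, 8] [0, 7, 12]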