Fixing prefix caching for flashdecoding.
parent 7f1816a4e1
commit 65b94a69bd
@@ -9,26 +9,48 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
     @dataclass
     class Seqlen:
         input_lengths: torch.Tensor
+        prefix_lengths: torch.Tensor
         cu_seqlen_q: Optional[torch.Tensor]
         cu_seqlen_k: Optional[torch.Tensor]
+        max_q: int
+        max_k: int
 
-        def __init__(self, input_lengths):
+        def __init__(
+            self,
+            input_lengths,
+            prefix_lengths,
+            cu_seqlen_q=None,
+            max_q=None,
+            max_k=None,
+        ):
             self.input_lengths = input_lengths
+            self.prefix_lengths = prefix_lengths
             device = self.input_lengths.device
             shape = self.input_lengths.shape
-            cu_seqlen_q = torch.arange(
-                shape[0] + 1,
-                device=device,
-                dtype=torch.int32,
-            )
+            if cu_seqlen_q is None:
+                cu_seqlen_q = torch.arange(
+                    shape[0] + 1,
+                    device=device,
+                    dtype=torch.int32,
+                )
+                max_q = 1
+            else:
+                assert max_q is not None
+                assert max_k is not None
             cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # if max_q is not None and max_q < 1000 and max_q > 1:
+            #     import ipdb;ipdb.set_trace()
+
             # cuda graphs don't like this and this is necessary to clamp within mistral
             # Although FA2 might not want the clamping
             # cu_seqlen_k[0] = 0
-            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
+            total = self.input_lengths + self.prefix_lengths
+            torch.cumsum(total, -1, out=cu_seqlen_k[1:])
 
             self.cu_seqlen_q = cu_seqlen_q
             self.cu_seqlen_k = cu_seqlen_k
+            self.max_q = max_q
+            self.max_k = max_k
 
         def clamp(self, max):
             # Flash decoding doesn't need to clamp
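Note (not part of the diff): a minimal sketch of what the reworked constructor computes when part of a sequence already sits in the prefix cache. The batch values below are invented for illustration; the point is that the cumulative key lengths now include the cached prefix.

import torch

# Hypothetical decode batch of 3 sequences: one new query token each,
# with 0, 16 and 32 tokens already stored in the prefix cache.
input_lengths = torch.tensor([1, 1, 1], dtype=torch.int32)
prefix_lengths = torch.tensor([0, 16, 32], dtype=torch.int32)

# Mirrors the new __init__: keys span prefix + new tokens.
total = input_lengths + prefix_lengths
cu_seqlen_k = torch.zeros(input_lengths.shape[-1] + 1, dtype=torch.int32)
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
print(cu_seqlen_k)  # tensor([ 0,  1, 18, 51], dtype=torch.int32)

Before this change the cumulative sum was taken over input_lengths alone, so the flashdecoding kernels were handed key ranges that ignored the cached prefix tokens.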
@@ -39,6 +61,11 @@ else:
     @dataclass
     class Seqlen:
         input_lengths: torch.Tensor
+        prefix_lengths: torch.Tensor
+        cu_seqlen_q: torch.Tensor
+        max_q: int
+        max_k: int
 
         def clamp(self, max):
+            raise NotImplementedError("Not implemented seqlen for paged")
             return Seqlen(torch.clamp(self.input_lengths, max=max))
@@ -222,12 +222,10 @@ if ATTENTION == "flashinfer":
 
     def attention(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
-        cu_seqlens,
-        max_s,
+        seqlen: Seqlen,
         block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
@@ -250,12 +248,10 @@ elif V2:
 
     def attention(
         q,
         k,
         v,
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
-        cu_seqlens,
-        max_s,
+        seqlen: Seqlen,
        block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
@@ -266,17 +262,17 @@ elif V2:
             raise ValueError("`window_size_left` must be > 0 or -1")
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
             v,
             key_cache,
             value_cache,
             out,
-            cu_seqlens,
-            cu_seqlens,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
             None,
             None,
             block_tables,
             None,
             None,
-            max_s,
-            max_s,
+            seqlen.max_q,
+            seqlen.max_k,
             0.0,
             softmax_scale,
             False,
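Note (illustration, not the kernel call itself): the reason the FA2 call now receives two distinct cumulative-length tensors instead of cu_seqlens twice is that, with prefix caching, queries only cover the non-cached tail of each prompt while keys cover prefix plus tail. A sketch with made-up lengths:

import torch

# Hypothetical prefill batch: prompts of 10 and 8 tokens,
# of which 6 and 0 tokens respectively are already cached.
new_tokens = torch.tensor([4, 8], dtype=torch.int32)
prefix_lengths = torch.tensor([6, 0], dtype=torch.int32)

cu_seqlen_q = torch.zeros(3, dtype=torch.int32)
torch.cumsum(new_tokens, -1, out=cu_seqlen_q[1:])                   # [0, 4, 12]

cu_seqlen_k = torch.zeros(3, dtype=torch.int32)
torch.cumsum(new_tokens + prefix_lengths, -1, out=cu_seqlen_k[1:])  # [0, 10, 18]

max_q = int(new_tokens.max())                     # 8
max_k = int((new_tokens + prefix_lengths).max())  # 10

In the same way, seqlen.max_q / seqlen.max_k replace the old max_s, max_s pair in the call above.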
@@ -32,6 +32,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -194,7 +195,7 @@ class FlashLlamaAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         adapter_data,
     ):
@@ -218,12 +219,10 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
                 block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -235,7 +234,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )
 
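Note (illustration only, not the TGI kernel): on the decode path, what paged attention ultimately needs from seqlen is the total number of cached keys per sequence (prefix plus generated tokens), because that decides how far to walk each entry of the block table. A standalone sketch assuming a block size of 16:

import torch

BLOCK_SIZE = 16  # assumed paged-KV block size for this example

input_lengths = torch.tensor([5, 40], dtype=torch.int32)   # tokens outside the cached prefix
prefix_lengths = torch.tensor([32, 0], dtype=torch.int32)  # tokens reused from the prefix cache

total_keys = input_lengths + prefix_lengths                  # tensor([37, 40])
blocks_needed = (total_keys + BLOCK_SIZE - 1) // BLOCK_SIZE  # tensor([3, 3])
print(blocks_needed)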
@@ -375,7 +374,7 @@ class FlashLlamaLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         adapter_data,
     ):
@@ -390,7 +389,7 @@ class FlashLlamaLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             adapter_data,
         )
@@ -479,7 +478,7 @@ class FlashLlamaModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -504,7 +503,7 @@ class FlashLlamaModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 adapter_data,
             )
@@ -548,7 +547,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -562,7 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
@@ -283,7 +283,6 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(tokenized_input)
 
             # Position ids
-            print(f"Prefix {prefix_len} - Orig {orig_input_length}")
             request_position_ids = torch.arange(
                 prefix_len, orig_input_length, dtype=torch.int32
             )
@@ -1158,8 +1157,15 @@ class FlashCausalLM(Model):
             "block_tables": block_tables,
             "slots": slots,
             "input_lengths": input_lengths_tensor,
+            "prefix_lengths": prefix_lengths_tensor,
         }
-        input_lengths_ = Seqlen(input_lengths=input_lengths_tensor)
+        seqlen = Seqlen(
+            input_lengths=input_lengths_tensor,
+            prefix_lengths=prefix_lengths_tensor,
+            cu_seqlen_q=None,
+            max_q=1,
+            max_k=max_s,
+        )
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
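Note: both here and inside the capture below, the Seqlen is built from the same pre-allocated input_lengths_tensor / prefix_lengths_tensor, because a CUDA graph replays fixed memory addresses; at decode time new values have to be copied into those buffers rather than fresh tensors being created. A generic, simplified sketch of that pattern (not TGI code; real code warms the ops up before capturing, as these hunks do):

import torch

static_lengths = torch.ones(8, dtype=torch.int32, device="cuda")
static_doubled = torch.empty(8, dtype=torch.int32, device="cuda")

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # Everything recorded here keeps pointing at static_lengths.
    static_doubled.copy_(static_lengths * 2)

# Later steps overwrite the static buffer in place, then replay the graph.
static_lengths.fill_(3)
graph.replay()
print(static_doubled)  # tensor of 6s once the replayed kernel has run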
@@ -1202,7 +1208,7 @@ class FlashCausalLM(Model):
             kv_cache=self.kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths_,
+            seqlen=seqlen,
             max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
@@ -1211,7 +1217,13 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
 
         with torch.cuda.graph(graph, pool=MEM_POOL):
-            input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor)
+            seqlen = Seqlen(
+                input_lengths=input_lengths_tensor,
+                prefix_lengths=prefix_lengths_tensor,
+                cu_seqlen_q=None,
+                max_q=1,
+                max_k=max_s,
+            )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -1219,7 +1231,7 @@ class FlashCausalLM(Model):
                 kv_cache=self.kv_cache,
                 block_tables=block_tables,
                 slots=slots,
-                input_lengths=input_lengths_tensor,
+                seqlen=seqlen,
                 max_s=max_s,
                 prefill_cache_indices=None,
                 lm_head_indices=None,
@@ -1358,18 +1370,26 @@ class FlashCausalLM(Model):
 
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        input_lengths = Seqlen(input_lengths=input_lengths)
+        prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
+        cu_seqlen_prefill = torch.tensor(
+            [0, seqlen], device=self.device, dtype=torch.int32
+        )
+        seqlen = Seqlen(
+            input_lengths=input_lengths,
+            prefix_lengths=prefix_lens_tensor,
+            cu_seqlen_q=cu_seqlen_prefill,
+            max_q=1,
+            max_k=seqlen,
+        )
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlen_prefill=torch.tensor(
-                [0, seqlen], device=self.device, dtype=torch.int32
-            ),
+            cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=self.kv_cache,
             block_tables=None,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
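Note: this warmup involves no cached prefix, so prefix_lengths is all zeros and the Seqlen constructor degenerates to the old behaviour, cu_seqlen_k being the plain cumulative sum of input_lengths. A quick standalone check with an assumed warmup length of 7:

import torch

n = 7  # assumed warmup length
input_lengths = torch.ones(n, dtype=torch.int32)
prefix_lengths = torch.zeros(n, dtype=torch.int32)

cu_seqlen_k = torch.zeros(n + 1, dtype=torch.int32)
torch.cumsum(input_lengths + prefix_lengths, -1, out=cu_seqlen_k[1:])

# With zero prefixes this is exactly the pre-fix cumulative sum.
assert torch.equal(cu_seqlen_k, torch.arange(n + 1, dtype=torch.int32))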
@@ -1449,7 +1469,8 @@ class FlashCausalLM(Model):
             cuda_graph = None
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            input_lengths = input_lengths + prefix_lens_tensor
+            # TODO
+            # input_lengths = input_lengths + prefix_lens_tensor
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
@@ -1464,7 +1485,14 @@ class FlashCausalLM(Model):
             prefix_lens=batch.prefix_lens,
             prefix_lens_tensor=prefix_lens_tensor,
         ):
-            input_lengths = Seqlen(input_lengths=input_lengths)
+            max_k = (input_lengths + prefix_lens_tensor).max().item()
+            seqlen = Seqlen(
+                input_lengths=input_lengths,
+                prefix_lengths=prefix_lens_tensor,
+                cu_seqlen_q=cu_seqlen_prefill,
+                max_q=max_s,
+                max_k=max_k,
+            )
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
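Note (made-up batch, just to check the arithmetic of the new max_k line): max_k is now the longest total key span in the batch, new tokens plus cached prefix, rather than being taken from max_s.

import torch

input_lengths = torch.tensor([12, 3, 1], dtype=torch.int32)
prefix_lens_tensor = torch.tensor([0, 64, 128], dtype=torch.int32)

max_k = (input_lengths + prefix_lens_tensor).max().item()
print(max_k)  # 129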
@@ -1472,7 +1500,7 @@ class FlashCausalLM(Model):
                 kv_cache=kv_cache,
                 block_tables=block_tables,
                 slots=slots,
-                input_lengths=input_lengths,
+                seqlen=seqlen,
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,