feat(server): pre-allocate past key values for flash causal LM (#412)

This commit is contained in:
OlivierDehaene 2023-06-12 18:30:29 +02:00 committed by GitHub
parent ca650e5bff
commit 5ce89059f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 494 additions and 345 deletions

View File

@ -1,9 +1,9 @@
flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7
flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f
flash-attention:
# Clone flash attention
pip install packaging
git clone https://github.com/HazyResearch/flash-attention.git
git clone https://github.com/OlivierDehaene/flash-attention.git
build-flash-attention: flash-attention
cd flash-attention && git fetch && git checkout $(flash_att_commit)

View File

@ -128,11 +128,14 @@ class FlashLlamaAttention(torch.nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -142,7 +145,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if layer_past_present_indices is None:
if prefill:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
@ -154,8 +157,10 @@ class FlashLlamaAttention(torch.nn.Module):
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlens,
cu_seqlens,
start_seq,
end_seq,
start_seq,
end_seq,
max_s,
max_s,
0.0,
@ -170,7 +175,7 @@ class FlashLlamaAttention(torch.nn.Module):
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv[:, 1:]
layer_past[past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)
@ -180,8 +185,10 @@ class FlashLlamaAttention(torch.nn.Module):
layer_past[:, 0],
layer_past[:, 1],
attn_output,
cu_seqlens_q,
cu_seqlens,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
@ -258,11 +265,14 @@ class FlashLlamaLayer(nn.Module):
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -271,11 +281,14 @@ class FlashLlamaLayer(nn.Module):
normed_hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
# faster post attention rms norm
@ -322,35 +335,37 @@ class FlashLlamaModel(torch.nn.Module):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_key_values: Optional[torch.Tensor] = None,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.embed_tokens(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
self.num_heads,
self.head_size,
)
)
layer_past_present_indices = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
slice_past_index = None
prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
@ -360,25 +375,36 @@ class FlashLlamaModel(torch.nn.Module):
residual = None
for i, layer in enumerate(self.layers):
# We added padding that we now need to slice
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
past_key_values[:, i],
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states, past_key_values
@ -399,9 +425,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -409,9 +438,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
hidden_states, present = self.model(
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)

View File

@ -113,11 +113,14 @@ class FlashNeoxAttention(torch.nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -127,7 +130,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if layer_past_present_indices is None:
if prefill:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
@ -139,8 +142,10 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlens,
cu_seqlens,
start_seq,
end_seq,
start_seq,
end_seq,
max_s,
max_s,
0.0,
@ -155,7 +160,7 @@ class FlashNeoxAttention(torch.nn.Module):
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv[:, 1:]
layer_past[past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)
@ -165,8 +170,10 @@ class FlashNeoxAttention(torch.nn.Module):
layer_past[:, 0],
layer_past[:, 1],
attn_output,
cu_seqlens_q,
cu_seqlens,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
@ -240,11 +247,14 @@ class FlashNeoXLayer(nn.Module):
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@ -253,11 +263,14 @@ class FlashNeoXLayer(nn.Module):
ln1_hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@ -276,11 +289,14 @@ class FlashNeoXLayer(nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
hidden_states, residual = self.post_attention_layernorm(
@ -329,9 +345,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
@ -339,25 +358,24 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
self.num_heads,
self.head_size,
)
)
layer_past_present_indices = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
slice_past_index = None
prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
@ -367,25 +385,36 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
residual = None
for i, layer in enumerate(self.layers):
# We added padding that we now need to slice
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
past_key_values[:, i],
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
return hidden_states, past_key_values
@ -404,9 +433,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -414,9 +446,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
hidden_states, present = self.gpt_neox(
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)

View File

@ -130,11 +130,14 @@ class FlashRWAttention(torch.nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
@ -150,10 +153,10 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, 0], cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
# Prefill
if layer_past_present_indices is None:
if prefill:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
@ -164,11 +167,13 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, 0],
kv[:, 1],
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
start_seq,
end_seq,
start_seq,
end_seq,
max_s,
max_s,
0.0,
@ -182,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
layer_past[past_present_indices] = kv
# Expand to query shape
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
@ -191,11 +196,13 @@ class FlashRWAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, 0],
kv[:, 1],
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
@ -261,11 +268,14 @@ class FlashRWLargeAttention(torch.nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@ -280,10 +290,10 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, :, 0], cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
# Prefill
if layer_past_present_indices is None:
if prefill:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
@ -298,11 +308,13 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, :, 0],
kv[:, :, 1],
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
start_seq,
end_seq,
start_seq,
end_seq,
max_s,
max_s,
0.0,
@ -316,7 +328,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
layer_past[past_present_indices] = kv
# Expand to query shape
kv = (
layer_past.unsqueeze(2)
@ -329,11 +341,13 @@ class FlashRWLargeAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, :, 0],
kv[:, :, 1],
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
@ -417,11 +431,14 @@ class FlashRWLayer(nn.Module):
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -430,11 +447,14 @@ class FlashRWLayer(nn.Module):
ln_hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
mlp_output = self.mlp(ln_hidden_states)
@ -451,11 +471,14 @@ class FlashRWLayer(nn.Module):
hidden_states,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
hidden_states, residual = self.post_attention_layernorm(
@ -499,11 +522,14 @@ class FlashRWLargeLayer(nn.Module):
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
@ -513,11 +539,14 @@ class FlashRWLargeLayer(nn.Module):
ln_attn,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
# MLP.
@ -584,9 +613,12 @@ class FlashRWModel(FlashRWPreTrainedModel):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
@ -594,23 +626,22 @@ class FlashRWModel(FlashRWPreTrainedModel):
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.h),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
*self.cache_size,
)
)
layer_past_present_indices = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
slice_past_index = None
prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
@ -620,25 +651,34 @@ class FlashRWModel(FlashRWPreTrainedModel):
residual = None
for i, layer in enumerate(self.h):
# We added padding that we now need to slice
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.h),
*self.cache_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values
@ -658,9 +698,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -668,9 +711,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
hidden_states, present = self.transformer(
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)

View File

@ -7,6 +7,7 @@ from typing import Optional
# Flash attention imports
import flash_attn_cuda
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module):
def forward(
self,
hidden_states,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
qkv = self.c_attn(hidden_states)
@ -166,7 +170,7 @@ class FlashMQAttention(torch.nn.Module):
key_value = key_value.view(-1, 2, 1, self.head_size)
# Prefill
if layer_past_present_indices is None:
if prefill:
# Copy to layer past
layer_past[...] = key_value
# Expand from 1 to num_heads
@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
key_value[:, 0],
key_value[:, 1],
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlens,
cu_seqlens,
start_seq,
end_seq,
start_seq,
end_seq,
max_s,
max_s,
0.0,
@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module):
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = key_value
layer_past[past_present_indices] = key_value
# Expand from 1 to num_heads
key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module):
# flash attention
flash_attn_cuda.fwd(
query,
key_value[:, 0],
key_value[:, 1],
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlens_q,
cu_seqlens,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
@ -277,21 +285,27 @@ class Block(nn.Module):
self,
hidden_states,
residual,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
past_present_indices,
prefill,
)
hidden_states, residual = self.ln_2(hidden_states, residual)
@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_key_values: Optional[torch.Tensor] = None,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@ -352,45 +369,43 @@ class FlashSantacoderModel(nn.Module):
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
past_key_values = hidden_states.new_empty(
(
len(self.h),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
1,
self.head_size,
)
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_zeros(
(len(input_ids), len(self.h), 2, 1, self.head_size)
)
layer_past_present_indices = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
layer_past_present_indices = cu_seqlens[1:] - 1
slice_past_index = None
prefill = False
residual = None
for i, layer in enumerate(self.h):
# We added padding that we now need to slice
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cu_seqlens,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
torch.select(past_key_values, dim=1, index=i),
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values
@ -408,9 +423,12 @@ class FlashSantacoderForCausalLM(nn.Module):
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -418,9 +436,12 @@ class FlashSantacoderForCausalLM(nn.Module):
hidden_states, present = self.transformer(
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)

View File

@ -3,8 +3,6 @@ import torch.distributed
import numpy as np
from torch.nn import functional as F
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
@ -34,10 +32,21 @@ class FlashCausalLMBatch(Batch):
input_ids: torch.Tensor
position_ids: torch.Tensor
# cumulative sequence lengths
cu_seqlens: torch.Tensor
# cumulative query sequence lengths, only used in decode
cu_seqlens_q: Optional[torch.Tensor]
# Indices to copy present to the correct indices is the pre-allocated past key values
past_present_indices: torch.Tensor
# tensor of length b holding starting offset of each sequence
start_seq: torch.Tensor
# tensor of length b holding ending offset of each sequence
end_seq: torch.Tensor
# tensor of length b holding starting offset of each sequence, only used in prefill
start_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding ending offset of each sequence, only used in prefill
end_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding starting offset of each query sequence, only used in decode
start_seq_q: Optional[torch.Tensor]
# tensor of length b holding ending offset of each query sequence, only used in decode
end_seq_q: Optional[torch.Tensor]
# past key values, only used in decode
past_key_values: Optional[torch.Tensor]
max_seqlen: int
@ -90,7 +99,11 @@ class FlashCausalLMBatch(Batch):
)["input_ids"]
position_ids = []
cu_seqlens = [0]
past_present_indices = []
start_seq = []
end_seq = []
start_seq_prefill = []
end_seq_prefill = []
max_seqlen = 0
input_lengths = []
@ -110,9 +123,9 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_length = 0
cumulative_max_length = 0
prefill_out_cumulative_length = 0
max_tokens = 0
max_length = 0
# Parse batch
@ -138,7 +151,10 @@ class FlashCausalLMBatch(Batch):
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
start_seq_prefill.append(cumulative_length)
end_seq_prefill.append(cumulative_length + input_length)
start_seq.append(cumulative_max_length)
end_seq.append(cumulative_max_length + input_length)
next_token_chooser_parameters.append(r.parameters)
@ -168,9 +184,17 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
request_past_present_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
past_present_indices.append(request_past_present_indices)
# Update
# Remove one as the first token des not have a past
cumulative_length += input_length
max_tokens += input_length + max_new_tokens
cumulative_max_length += input_length + max_new_tokens - 1
max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@ -184,26 +208,45 @@ class FlashCausalLMBatch(Batch):
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
start_seq_prefill = torch.tensor(
start_seq_prefill, device=device, dtype=torch.int32
)
end_seq_prefill = torch.tensor(
end_seq_prefill, device=device, dtype=torch.int32
)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
# Create tensors on device
past_present_indices = past_present_indices[0]
start_seq_prefill = start_seq
end_seq_prefill = end_seq
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
past_present_indices = torch.tensor(
past_present_indices, device=device, dtype=torch.int64
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlens[1:] - 1
prefill_next_token_indices = end_seq_prefill - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlens[1:] - 1
prefill_head_indices = end_seq_prefill - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
@ -219,8 +262,13 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
cu_seqlens_q=None,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=start_seq_prefill,
end_seq_prefill=end_seq_prefill,
start_seq_q=None,
end_seq_q=None,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
@ -233,7 +281,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
max_tokens=cumulative_max_length,
)
@tracer.start_as_current_span("filter")
@ -244,10 +292,10 @@ class FlashCausalLMBatch(Batch):
if len(request_ids) == len(self):
return self
single_request = len(request_ids) == 1
device = self.input_ids.device
# Cumulative length
cumulative_length = 0
cumulative_max_length = 0
# New values after filtering
requests_idx_mapping = {}
@ -255,11 +303,17 @@ class FlashCausalLMBatch(Batch):
# Used to index into tensors
indices = []
# past indices to keep
past_indices = torch.zeros(
self.past_key_values.shape[0], dtype=torch.bool, device=device
)
# Create on CPU to only move to GPU once instead of at every copy
cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
start_seq = torch.empty(len(request_ids), dtype=torch.int32)
end_seq = torch.empty(len(request_ids), dtype=torch.int32)
start_seq_q = self.start_seq_q[: len(request_ids)]
end_seq_q = self.end_seq_q[: len(request_ids)]
max_seqlen = 0
past_key_values = []
requests = []
all_input_ids = []
@ -270,8 +324,6 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
max_tokens = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
@ -281,16 +333,8 @@ class FlashCausalLMBatch(Batch):
# Get length
request_input_length = self.input_lengths[idx]
# Copy to tensor (CPU)
cu_seqlens[i + 1] = cumulative_length + request_input_length
max_seqlen = max(max_seqlen, request_input_length)
# Slice from past
past_key_values.append(
self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
)
all_input_ids.append(self.all_input_ids[idx])
input_lengths.append(request_input_length)
@ -300,39 +344,32 @@ class FlashCausalLMBatch(Batch):
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
cumulative_length += request_input_length
max_tokens += request_input_length + (
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
if single_request:
# Preallocate tensor for bs = 1 case
past_key_values = F.pad(
past_key_values[0],
(
0,
0,
0,
0,
0,
0,
0,
stopping_criterias[0].max_new_tokens
- stopping_criterias[0].current_tokens,
),
)
else:
# Cat all past
past_key_values = torch.cat(past_key_values, dim=1)
# Copy to tensor (CPU)
start_seq[i] = cumulative_max_length
end_seq[i] = cumulative_max_length + request_input_length
# Set slice
past_indices[
self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1
] = True
cumulative_max_length += request_input_length + remaining_tokens - 1
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices)
past_key_values = self.past_key_values[past_indices]
# Move to GPU now that we have the whole tensor
cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
start_seq = start_seq.to(device)
end_seq = end_seq.to(device)
past_present_indices = end_seq - 1
return FlashCausalLMBatch(
batch_id=self.batch_id,
@ -340,8 +377,13 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=None,
end_seq_prefill=None,
start_seq_q=start_seq_q,
end_seq_q=end_seq_q,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
@ -354,7 +396,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
max_tokens=cumulative_max_length,
)
@classmethod
@ -371,10 +413,12 @@ class FlashCausalLMBatch(Batch):
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
cu_seqlens = [0]
cu_seqlens_q = torch.arange(
0, total_batch_size + 1, device=device, dtype=torch.int32
start_seq = batches[0].start_seq.new_empty(total_batch_size)
end_seq = batches[0].end_seq.new_empty(total_batch_size)
start_seq_q = torch.arange(
0, total_batch_size, device=device, dtype=torch.int32
)
end_seq_q = start_seq_q + 1
max_seqlen = 0
past_key_values = []
@ -389,7 +433,6 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_batch_size = 0
cumulative_length = 0
max_tokens = 0
max_length = 0
@ -410,18 +453,10 @@ class FlashCausalLMBatch(Batch):
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
# Add cumulative lengths of all previous inputs
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
max_seqlen = max(max_seqlen, batch.max_seqlen)
start_seq[start_index:end_index] = batch.start_seq + max_tokens
end_seq[start_index:end_index] = batch.end_seq + max_tokens
if len(batch) != 1:
past_key_values.append(batch.past_key_values)
else:
# past was pre-allocated for this batch
# We need to slice to remove the padding
past_key_values.append(
batch.past_key_values[:, : batch.input_lengths[0]]
)
max_seqlen = max(max_seqlen, batch.max_seqlen)
all_input_ids.extend(batch.all_input_ids)
@ -431,9 +466,9 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias)
past_key_values.append(batch.past_key_values)
# Update
cumulative_length += batch.cu_seqlens[-1]
cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens
max_length = max(
@ -448,6 +483,9 @@ class FlashCausalLMBatch(Batch):
),
)
past_key_values = torch.cat(past_key_values, dim=0)
past_present_indices = end_seq - 1
all_input_ids_tensor = torch.zeros(
(total_batch_size, max_length), dtype=torch.int64, device=device
)
@ -463,11 +501,6 @@ class FlashCausalLMBatch(Batch):
cumulative_batch_size += len(batch)
# Cat past
past_key_values = torch.cat(past_key_values, dim=1)
# Create final tensor on GPU
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype=dtype, device=device
)
@ -478,8 +511,13 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=None,
end_seq_prefill=None,
start_seq_q=start_seq_q,
end_seq_q=end_seq_q,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
@ -550,9 +588,12 @@ class FlashCausalLM(Model):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlens: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor],
start_seq: torch.Tensor,
end_seq: torch.Tensor,
start_seq_q: Optional[torch.Tensor],
end_seq_q: Optional[torch.Tensor],
max_s: int,
past_present_indices: torch.Tensor,
past_key_values: Optional = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -561,9 +602,12 @@ class FlashCausalLM(Model):
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlens=cu_seqlens,
cu_seqlens_q=cu_seqlens_q,
start_seq=start_seq,
end_seq=end_seq,
start_seq_q=start_seq_q,
end_seq_q=end_seq_q,
max_s=max_s,
past_present_indices=past_present_indices,
past_key_values=past_key_values,
pre_allocate_past_size=pre_allocate_past_size,
lm_head_indices=lm_head_indices,
@ -575,23 +619,27 @@ class FlashCausalLM(Model):
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None
prefill_logprobs = batch.prefill_next_token_indices is not None
single_request = len(batch) == 1
if prefill and single_request:
if prefill:
# Ask to pre-allocate kv to its max size
# == number of tokens + max_new_tokens
pre_allocate_past_size = (
batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
)
# == Sum over batch size (number of tokens + max_new_tokens) - batch size
pre_allocate_past_size = batch.max_tokens
start_seq = batch.start_seq_prefill
end_seq = batch.end_seq_prefill
else:
pre_allocate_past_size = None
start_seq = batch.start_seq
end_seq = batch.end_seq
out, present = self.forward(
batch.input_ids,
batch.position_ids,
batch.cu_seqlens,
batch.cu_seqlens_q,
start_seq,
end_seq,
batch.start_seq_q,
batch.end_seq_q,
batch.max_seqlen,
batch.past_present_indices,
batch.past_key_values,
pre_allocate_past_size,
batch.prefill_head_indices,
@ -614,55 +662,19 @@ class FlashCausalLM(Model):
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
# Create batch.cu_seqlens_q for decode
batch.cu_seqlens_q = torch.arange(
0, len(batch) + 1, device=self.device, dtype=torch.int32
# Create batch.start_seq_q and batch.end_seq_q for decode
batch.start_seq_q = torch.arange(
0, len(batch), device=self.device, dtype=torch.int32
)
batch.end_seq_q = batch.start_seq_q + 1
next_position_ids = batch.position_ids.new_empty(len(batch))
# We do not need start_seq_prefill and end_seq_prefill anymore
batch.start_seq_prefill = None
batch.end_seq_prefill = None
else:
prefill_logprobs = None
next_position_ids = batch.position_ids
# Prepare past for next decode
if len(batch) > 1:
# Used to slice next batch past
past_indices = torch.empty(
present.shape[1], dtype=torch.int64, device=self.device
)
batch.past_key_values = present.new_empty(
(
present.shape[0],
present.shape[1] + len(batch.requests),
*present.shape[2:],
)
)
# It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
# and will run asynchronously while we do the next for loop
cumulative_length = 0
for i, input_length in enumerate(batch.input_lengths):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
# Indices to copy present at the correct place in past_key_values
torch.arange(
start_index + i,
end_index + i,
dtype=torch.int64,
device=self.device,
out=past_indices[start_index:end_index],
)
cumulative_length += input_length
# Copy from present to past_key_values
batch.past_key_values[:, past_indices] = present
# Initialize past_key_values in prefill for len(batch) == 1
elif prefill:
# present is already pre-padded
batch.past_key_values = present
# Cumulative length
cumulative_length = 0
@ -685,6 +697,7 @@ class FlashCausalLM(Model):
input_length,
all_input_ids,
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
end_index = cumulative_length + input_length
@ -718,7 +731,8 @@ class FlashCausalLM(Model):
# Set values in batch
batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1
batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
batch.past_present_indices = batch.end_seq
batch.end_seq = batch.end_seq + 1
if prefill and prefill_logprobs:
# Get prefill logprobs
@ -843,6 +857,7 @@ class FlashCausalLM(Model):
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
batch.max_seqlen = batch.max_seqlen + 1
batch.past_key_values = present
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None