feat(server): pre-allocate past key values for flash causal LM (#412)
This commit is contained in:
parent ca650e5bff
commit 5ce89059f8
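The change in a nutshell: instead of concatenating (decode) or padding (prefill) per-request KV caches on every step, the server allocates the whole cache for a batch once and scatters new keys and values into it by index. The sketch below is illustrative only and is not code from this commit; the tensor names mirror the diff, the sizes are toy values, and the `- 1` follows the diff's rule that the last generated token never needs a cache slot.

```python
import torch

num_layers, num_heads, head_size = 2, 4, 8

# One request with input_length=3 and max_new_tokens=4 needs
# input_length + max_new_tokens - 1 = 6 cache slots.
pre_allocate_past_size = 6

# Prefill writes the first input_length slots.
past_present_indices = torch.arange(0, 3, dtype=torch.int64)

present = torch.randn(3, num_layers, 2, num_heads, head_size)  # K and V per layer
past_key_values = present.new_empty(
    (pre_allocate_past_size, num_layers, 2, num_heads, head_size)
)
# Scatter once for the whole model instead of slicing at every layer.
past_key_values[past_present_indices] = present

# Each decode step then writes a single slot per sequence, in place.
step_index = torch.tensor([3])  # the diff derives this as end_seq - 1
past_key_values[step_index] = torch.randn(1, num_layers, 2, num_heads, head_size)
```

This scatter-by-index trick is what `past_present_indices` implements throughout the diff below.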
@@ -1,9 +1,9 @@
-flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7
+flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f

 flash-attention:
 	# Clone flash attention
 	pip install packaging
-	git clone https://github.com/HazyResearch/flash-attention.git
+	git clone https://github.com/OlivierDehaene/flash-attention.git

 build-flash-attention: flash-attention
 	cd flash-attention && git fetch && git checkout $(flash_att_commit)
@@ -128,11 +128,14 @@ class FlashLlamaAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@@ -142,7 +145,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 1], cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = qkv[:, 1:]

@@ -154,8 +157,10 @@ class FlashLlamaAttention(torch.nn.Module):
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -170,7 +175,7 @@ class FlashLlamaAttention(torch.nn.Module):
         else:
             query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = qkv[:, 1:]

             # output
             attn_output = torch.empty_like(query)
@@ -180,8 +185,10 @@ class FlashLlamaAttention(torch.nn.Module):
                 layer_past[:, 0],
                 layer_past[:, 1],
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -258,11 +265,14 @@ class FlashLlamaLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

@@ -271,11 +281,14 @@ class FlashLlamaLayer(nn.Module):
             normed_hidden_states,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )

         # faster post attention rms norm
@@ -322,35 +335,37 @@ class FlashLlamaModel(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_tokens(input_ids)

         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -360,25 +375,36 @@ class FlashLlamaModel(torch.nn.Module):

         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_key_values[:, i],
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.norm(hidden_states, residual)

         return hidden_states, past_key_values
@@ -399,9 +425,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -409,9 +438,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         hidden_states, present = self.model(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )

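Every attention module in this commit (llama above, then NeoX, RW, and Santacoder below) now receives an explicit `prefill` flag plus scatter indices instead of re-deriving the decode indices from `cu_seqlens`. A condensed, hedged sketch of the cache-update pattern the modules share; the rotary embedding and the `flash_attn_cuda.fwd` call are omitted, and the shapes are toy values:

```python
import torch

def attention_cache_update(qkv, layer_past, past_present_indices, prefill):
    """Simplified sketch of the shared cache-update pattern, not library code."""
    if prefill:
        # layer_past is a packed view covering exactly the prompt tokens,
        # so the whole present can be copied in one shot.
        layer_past[...] = qkv[:, 1:]
    else:
        # One new token per sequence: scatter K/V into the pre-allocated
        # cache at each sequence's current write position.
        layer_past[past_present_indices] = qkv[:, 1:]
    return layer_past

# Toy shapes: 5 prompt tokens, 3 heads of size 4.
qkv = torch.randn(5, 3, 3, 4)          # (tokens, q/k/v, heads, head_size)
layer_past = torch.empty(5, 2, 3, 4)   # (tokens, k/v, heads, head_size)
attention_cache_update(qkv, layer_past, None, prefill=True)
```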
@@ -113,11 +113,14 @@ class FlashNeoxAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@@ -127,7 +130,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 1], cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = qkv[:, 1:]

@@ -139,8 +142,10 @@ class FlashNeoxAttention(torch.nn.Module):
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -155,7 +160,7 @@ class FlashNeoxAttention(torch.nn.Module):
         else:
             query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv[:, 1:]
+            layer_past[past_present_indices] = qkv[:, 1:]

             # output
             attn_output = torch.empty_like(query)
@@ -165,8 +170,10 @@ class FlashNeoxAttention(torch.nn.Module):
                 layer_past[:, 0],
                 layer_past[:, 1],
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -240,11 +247,14 @@ class FlashNeoXLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         if self.use_parallel_residual:
             ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@@ -253,11 +263,14 @@ class FlashNeoXLayer(nn.Module):
                 ln1_hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@@ -276,11 +289,14 @@ class FlashNeoXLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             hidden_states, residual = self.post_attention_layernorm(
@@ -329,9 +345,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
@@ -339,25 +358,24 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):

         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.layers),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -367,25 +385,36 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):

         residual = None
         for i, layer in enumerate(self.layers):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_key_values[:, i],
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.layers),
+                    2,
+                    self.num_heads,
+                    self.head_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.final_layer_norm(hidden_states, residual)

         return hidden_states, past_key_values
@@ -404,9 +433,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -414,9 +446,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         hidden_states, present = self.gpt_neox(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )

@@ -130,11 +130,14 @@ class FlashRWAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)

@@ -150,10 +153,10 @@ class FlashRWAttention(torch.nn.Module):

         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
@@ -164,11 +167,13 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -182,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)

@@ -191,11 +196,13 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -261,11 +268,14 @@ class FlashRWLargeAttention(torch.nn.Module):
         hidden_states,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@@ -280,10 +290,10 @@ class FlashRWLargeAttention(torch.nn.Module):

         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, :, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
@@ -298,11 +308,13 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -316,7 +328,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = (
                 layer_past.unsqueeze(2)
@@ -329,11 +341,13 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                kv[:, :, 0],
-                kv[:, :, 1],
+                torch.select(kv, dim=2, index=0),
+                torch.select(kv, dim=2, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -417,11 +431,14 @@ class FlashRWLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         if self.parallel_attn:
             ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -430,11 +447,14 @@ class FlashRWLayer(nn.Module):
                 ln_hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             mlp_output = self.mlp(ln_hidden_states)
@@ -451,11 +471,14 @@ class FlashRWLayer(nn.Module):
                 hidden_states,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                past_present_indices,
+                prefill,
             )

             hidden_states, residual = self.post_attention_layernorm(
@@ -499,11 +522,14 @@ class FlashRWLargeLayer(nn.Module):
         residual,
         cos,
         sin,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         ln_attn, residual = self.ln_attn(hidden_states, residual)
         ln_mlp, _ = self.ln_mlp(residual)
@@ -513,11 +539,14 @@ class FlashRWLargeLayer(nn.Module):
             ln_attn,
             cos,
             sin,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )

         # MLP.
@@ -584,9 +613,12 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
@@ -594,23 +626,22 @@ class FlashRWModel(FlashRWPreTrainedModel):

         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
             past_key_values = hidden_states.new_empty(
                 (
+                    len(input_ids),
                     len(self.h),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     *self.cache_size,
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -620,25 +651,34 @@ class FlashRWModel(FlashRWPreTrainedModel):

         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cos,
                 sin,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    pre_allocate_past_size,
+                    len(self.h),
+                    *self.cache_size,
+                )
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)

         return hidden_states, past_key_values
@@ -658,9 +698,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -668,9 +711,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )

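The RW hunks above (and the Santacoder hunks that follow) also swap subscripting such as `kv[:, 0]` for `torch.select(kv, dim=1, index=0)`. The two forms are equivalent here; both return a view of the same storage without copying. A minimal check, not part of the commit:

```python
import torch

kv = torch.randn(10, 2, 4, 8)  # (tokens, k/v, heads, head_size)

a = kv[:, 0]
b = torch.select(kv, dim=1, index=0)

assert torch.equal(a, b)
assert a.data_ptr() == b.data_ptr()  # both are views of the same storage
```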
@@ -7,6 +7,7 @@ from typing import Optional

 # Flash attention imports
 import flash_attn_cuda
+
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module):
     def forward(
         self,
         hidden_states,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.c_attn(hidden_states)

@@ -166,7 +170,7 @@ class FlashMQAttention(torch.nn.Module):
         key_value = key_value.view(-1, 2, 1, self.head_size)

         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
             layer_past[...] = key_value
             # Expand from 1 to num_heads
@@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = key_value
+            layer_past[past_present_indices] = key_value
             # Expand from 1 to num_heads
             key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)

@@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -277,21 +285,27 @@ class Block(nn.Module):
         self,
         hidden_states,
         residual,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)

         hidden_states = self.attn(
             hidden_states,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )

         hidden_states, residual = self.ln_2(hidden_states, residual)
@@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@@ -352,45 +369,43 @@ class FlashSantacoderModel(nn.Module):

         # Prefill
         if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
-            past_key_values = hidden_states.new_empty(
-                (
-                    len(self.h),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
-                    2,
-                    1,
-                    self.head_size,
-                )
+            # We create a tensor of the same size as input_ids as we don't want to slice at every layer
+            past_key_values = hidden_states.new_zeros(
+                (len(input_ids), len(self.h), 2, 1, self.head_size)
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False

         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )

+        if prefill:
+            present = past_key_values
+            # Create padded past tensor
+            past_key_values = hidden_states.new_empty(
+                (pre_allocate_past_size, len(self.h), 2, 1, self.head_size)
+            )
+            # We slice only once instead of at every layer
+            past_key_values[past_present_indices] = present
+
         hidden_states, _ = self.ln_f(hidden_states, residual)

         return hidden_states, past_key_values
@@ -408,9 +423,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -418,9 +436,12 @@ class FlashSantacoderForCausalLM(nn.Module):
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )

@@ -3,8 +3,6 @@ import torch.distributed

 import numpy as np

-from torch.nn import functional as F
-
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
@@ -34,10 +32,21 @@ class FlashCausalLMBatch(Batch):
     input_ids: torch.Tensor
     position_ids: torch.Tensor

-    # cumulative sequence lengths
-    cu_seqlens: torch.Tensor
-    # cumulative query sequence lengths, only used in decode
-    cu_seqlens_q: Optional[torch.Tensor]
+    # Indices to copy present to the correct indices in the pre-allocated past key values
+    past_present_indices: torch.Tensor
+
+    # tensor of length b holding starting offset of each sequence
+    start_seq: torch.Tensor
+    # tensor of length b holding ending offset of each sequence
+    end_seq: torch.Tensor
+    # tensor of length b holding starting offset of each sequence, only used in prefill
+    start_seq_prefill: Optional[torch.Tensor]
+    # tensor of length b holding ending offset of each sequence, only used in prefill
+    end_seq_prefill: Optional[torch.Tensor]
+    # tensor of length b holding starting offset of each query sequence, only used in decode
+    start_seq_q: Optional[torch.Tensor]
+    # tensor of length b holding ending offset of each query sequence, only used in decode
+    end_seq_q: Optional[torch.Tensor]
     # past key values, only used in decode
     past_key_values: Optional[torch.Tensor]
     max_seqlen: int
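The dataclass hunk above replaces the single cumulative tensor with independent per-sequence offset pairs. The reason, to the best of my reading of the diff: with `cu_seqlens`, sequence `i` must occupy the packed range `[cu_seqlens[i], cu_seqlens[i + 1])`, while start/end pairs let each sequence own a larger, padded slot range inside the pre-allocated cache. An illustrative comparison with made-up lengths (not code from the commit):

```python
import torch

input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
zero = torch.zeros(1, dtype=torch.int32)

# Old layout: sequences packed back to back behind one cumulative tensor.
cu_seqlens = torch.cat([zero, input_lengths.cumsum(0).to(torch.int32)])

# New layout: each sequence is granted input_length + max_new_tokens - 1
# cache slots, so starts are no longer plain cumulative sums of lengths.
max_new_tokens = torch.tensor([4, 2, 7], dtype=torch.int32)
slots = input_lengths + max_new_tokens - 1
start_seq = torch.cat([zero, slots.cumsum(0)[:-1].to(torch.int32)])
end_seq = start_seq + input_lengths

print(cu_seqlens.tolist())  # [0, 3, 8, 10]
print(start_seq.tolist())   # [0, 6, 12]
print(end_seq.tolist())     # [3, 11, 14]
```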
@@ -90,7 +99,11 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]

         position_ids = []
-        cu_seqlens = [0]
+        past_present_indices = []
+        start_seq = []
+        end_seq = []
+        start_seq_prefill = []
+        end_seq_prefill = []
         max_seqlen = 0

         input_lengths = []
@@ -110,9 +123,9 @@ class FlashCausalLMBatch(Batch):

         # Cumulative length
         cumulative_length = 0
+        cumulative_max_length = 0
         prefill_out_cumulative_length = 0

-        max_tokens = 0
         max_length = 0

         # Parse batch
@@ -138,7 +151,10 @@ class FlashCausalLMBatch(Batch):
             position_ids.append(request_position_ids)

             # Add cumulative lengths of all previous inputs
-            cu_seqlens.append(cumulative_length + input_length)
+            start_seq_prefill.append(cumulative_length)
+            end_seq_prefill.append(cumulative_length + input_length)
+            start_seq.append(cumulative_max_length)
+            end_seq.append(cumulative_max_length + input_length)

             next_token_chooser_parameters.append(r.parameters)

@@ -168,9 +184,17 @@ class FlashCausalLMBatch(Batch):
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1

+            request_past_present_indices = torch.arange(
+                cumulative_max_length,
+                cumulative_max_length + input_length,
+                dtype=torch.int64,
+            )
+            past_present_indices.append(request_past_present_indices)
+
             # Update
+            # Remove one as the first token does not have a past
             cumulative_length += input_length
-            max_tokens += input_length + max_new_tokens
+            cumulative_max_length += input_length + max_new_tokens - 1
             max_length = max(max_length, input_length + max_new_tokens)

         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@@ -184,26 +208,45 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids

+        # Create tensors on device
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )
+        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
+        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
+
         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
+
+            past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
+
+            start_seq_prefill = torch.tensor(
+                start_seq_prefill, device=device, dtype=torch.int32
+            )
+            end_seq_prefill = torch.tensor(
+                end_seq_prefill, device=device, dtype=torch.int32
+            )
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]

-            # Create tensors on device
+            past_present_indices = past_present_indices[0]
+
+            start_seq_prefill = start_seq
+            end_seq_prefill = end_seq
+
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
+        past_present_indices = torch.tensor(
+            past_present_indices, device=device, dtype=torch.int64
+        )

         if all_prefill_logprobs:
             prefill_head_indices = None
-            prefill_next_token_indices = cu_seqlens[1:] - 1
+            prefill_next_token_indices = end_seq_prefill - 1
         elif no_prefill_logprobs:
-            prefill_head_indices = cu_seqlens[1:] - 1
+            prefill_head_indices = end_seq_prefill - 1
             prefill_next_token_indices = None
         else:
             prefill_head_indices = torch.tensor(
@@ -219,8 +262,13 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            cu_seqlens_q=None,
+            past_present_indices=past_present_indices,
+            start_seq=start_seq,
+            end_seq=end_seq,
+            start_seq_prefill=start_seq_prefill,
+            end_seq_prefill=end_seq_prefill,
+            start_seq_q=None,
+            end_seq_q=None,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
@@ -233,7 +281,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
-            max_tokens=max_tokens,
+            max_tokens=cumulative_max_length,
         )

     @tracer.start_as_current_span("filter")
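A hedged, runnable walk-through of the prefill bookkeeping that `from_pb` performs after this change, with two made-up requests; variable names mirror the diff, and the loop body is heavily simplified:

```python
import torch

past_present_indices = []
start_seq, end_seq = [], []
start_seq_prefill, end_seq_prefill = [], []
cumulative_length = 0
cumulative_max_length = 0

for input_length, max_new_tokens in [(3, 4), (2, 3)]:
    # Packed offsets used by the prefill attention call
    start_seq_prefill.append(cumulative_length)
    end_seq_prefill.append(cumulative_length + input_length)
    # Offsets of this request's slots in the pre-allocated cache
    start_seq.append(cumulative_max_length)
    end_seq.append(cumulative_max_length + input_length)
    # Cache slots the prefill tokens are scattered into
    past_present_indices.append(
        torch.arange(cumulative_max_length, cumulative_max_length + input_length)
    )
    cumulative_length += input_length
    # One slot per token, minus one: the last token has no past
    cumulative_max_length += input_length + max_new_tokens - 1

print(torch.cat(past_present_indices).tolist())  # [0, 1, 2, 6, 7]
```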
@@ -244,10 +292,10 @@ class FlashCausalLMBatch(Batch):
         if len(request_ids) == len(self):
             return self

-        single_request = len(request_ids) == 1
+        device = self.input_ids.device

         # Cumulative length
-        cumulative_length = 0
+        cumulative_max_length = 0

         # New values after filtering
         requests_idx_mapping = {}
@@ -255,11 +303,17 @@ class FlashCausalLMBatch(Batch):
         # Used to index into tensors
         indices = []

+        # past indices to keep
+        past_indices = torch.zeros(
+            self.past_key_values.shape[0], dtype=torch.bool, device=device
+        )
+
         # Create on CPU to only move to GPU once instead of at every copy
-        cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
-        cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
+        start_seq = torch.empty(len(request_ids), dtype=torch.int32)
+        end_seq = torch.empty(len(request_ids), dtype=torch.int32)
+        start_seq_q = self.start_seq_q[: len(request_ids)]
+        end_seq_q = self.end_seq_q[: len(request_ids)]
         max_seqlen = 0
-        past_key_values = []

         requests = []
         all_input_ids = []
@@ -270,8 +324,6 @@ class FlashCausalLMBatch(Batch):

         stopping_criterias = []

-        max_tokens = 0
-
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
             indices.append(idx)
@@ -281,16 +333,8 @@ class FlashCausalLMBatch(Batch):

             # Get length
             request_input_length = self.input_lengths[idx]
-
-            # Copy to tensor (CPU)
-            cu_seqlens[i + 1] = cumulative_length + request_input_length
             max_seqlen = max(max_seqlen, request_input_length)

-            # Slice from past
-            past_key_values.append(
-                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
-            )
-
             all_input_ids.append(self.all_input_ids[idx])

             input_lengths.append(request_input_length)
@@ -300,39 +344,32 @@ class FlashCausalLMBatch(Batch):
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)

-            cumulative_length += request_input_length
-            max_tokens += request_input_length + (
+            remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )

-        if single_request:
-            # Preallocate tensor for bs = 1 case
-            past_key_values = F.pad(
-                past_key_values[0],
-                (
-                    0,
-                    0,
-                    0,
-                    0,
-                    0,
-                    0,
-                    0,
-                    stopping_criterias[0].max_new_tokens
-                    - stopping_criterias[0].current_tokens,
-                ),
-            )
-        else:
-            # Cat all past
-            past_key_values = torch.cat(past_key_values, dim=1)
+            # Copy to tensor (CPU)
+            start_seq[i] = cumulative_max_length
+            end_seq[i] = cumulative_max_length + request_input_length
+
+            # Set slice
+            past_indices[
+                self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1
+            ] = True
+
+            cumulative_max_length += request_input_length + remaining_tokens - 1

         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
+        past_key_values = self.past_key_values[past_indices]

         # Move to GPU now that we have the whole tensor
-        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
+        start_seq = start_seq.to(device)
+        end_seq = end_seq.to(device)
+        past_present_indices = end_seq - 1

         return FlashCausalLMBatch(
             batch_id=self.batch_id,
@@ -340,8 +377,13 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            cu_seqlens_q=cu_seqlens_q,
+            past_present_indices=past_present_indices,
+            start_seq=start_seq,
+            end_seq=end_seq,
+            start_seq_prefill=None,
+            end_seq_prefill=None,
+            start_seq_q=start_seq_q,
+            end_seq_q=end_seq_q,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -354,7 +396,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
-            max_tokens=max_tokens,
+            max_tokens=cumulative_max_length,
         )

     @classmethod
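With the pre-allocated layout, `filter` no longer slices and concatenates per-request cache chunks; it marks the slots to keep in one boolean mask and gathers once. A toy illustration with invented shapes, not code from the commit:

```python
import torch

cache_len = 20
past_key_values = torch.randn(cache_len, 2, 2, 4, 8)

past_indices = torch.zeros(cache_len, dtype=torch.bool)
# Suppose the kept request owns slots [6, 14) in the old cache
past_indices[6:14] = True

filtered = past_key_values[past_indices]  # one gather
print(filtered.shape)  # torch.Size([8, 2, 2, 4, 8])
```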
@ -371,10 +413,12 @@ class FlashCausalLMBatch(Batch):
|
||||||
|
|
||||||
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
||||||
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
||||||
cu_seqlens = [0]
|
start_seq = batches[0].start_seq.new_empty(total_batch_size)
|
||||||
cu_seqlens_q = torch.arange(
|
end_seq = batches[0].end_seq.new_empty(total_batch_size)
|
||||||
0, total_batch_size + 1, device=device, dtype=torch.int32
|
start_seq_q = torch.arange(
|
||||||
|
0, total_batch_size, device=device, dtype=torch.int32
|
||||||
)
|
)
|
||||||
|
end_seq_q = start_seq_q + 1
|
||||||
max_seqlen = 0
|
max_seqlen = 0
|
||||||
past_key_values = []
|
past_key_values = []
|
||||||
|
|
||||||
|
@ -389,7 +433,6 @@ class FlashCausalLMBatch(Batch):
|
||||||
|
|
||||||
# Cumulative length
|
# Cumulative length
|
||||||
cumulative_batch_size = 0
|
cumulative_batch_size = 0
|
||||||
cumulative_length = 0
|
|
||||||
max_tokens = 0
|
max_tokens = 0
|
||||||
max_length = 0
|
max_length = 0
|
||||||
|
|
||||||
|
@@ -410,18 +453,10 @@ class FlashCausalLMBatch(Batch):
             input_ids[start_index:end_index] = batch.input_ids
             position_ids[start_index:end_index] = batch.position_ids

-            # Add cumulative lengths of all previous inputs
-            cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
-            max_seqlen = max(max_seqlen, batch.max_seqlen)
-
-            if len(batch) != 1:
-                past_key_values.append(batch.past_key_values)
-            else:
-                # past was pre-allocated for this batch
-                # We need to slice to remove the padding
-                past_key_values.append(
-                    batch.past_key_values[:, : batch.input_lengths[0]]
-                )
+            start_seq[start_index:end_index] = batch.start_seq + max_tokens
+            end_seq[start_index:end_index] = batch.end_seq + max_tokens
+            max_seqlen = max(max_seqlen, batch.max_seqlen)

             all_input_ids.extend(batch.all_input_ids)

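Each source batch's packed past spans `batch.max_tokens` slots (tokens already seen plus head-room for future ones), so concatenation shifts offsets by the running slot total rather than by token counts. A toy illustration with invented numbers:

    import torch

    # Two hypothetical batches: (start_seq, end_seq, max_tokens they reserve)
    batch_a = (torch.tensor([0]), torch.tensor([5]), 20)  # 5 tokens, 15 spare slots
    batch_b = (torch.tensor([0]), torch.tensor([3]), 10)  # 3 tokens, 7 spare slots

    max_tokens = 0
    merged_start, merged_end = [], []
    for start, end, reserved in (batch_a, batch_b):
        # Shift by the slots already claimed by earlier batches, padding included
        merged_start.append(start + max_tokens)
        merged_end.append(end + max_tokens)
        max_tokens += reserved

    start_seq = torch.cat(merged_start)  # tensor([ 0, 20])
    end_seq = torch.cat(merged_end)      # tensor([ 5, 23])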
@@ -431,9 +466,9 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
+            past_key_values.append(batch.past_key_values)

             # Update
-            cumulative_length += batch.cu_seqlens[-1]
             cumulative_batch_size += len(batch)
             max_tokens += batch.max_tokens
             max_length = max(
@@ -448,6 +483,9 @@ class FlashCausalLMBatch(Batch):
                 ),
             )

+        past_key_values = torch.cat(past_key_values, dim=0)
+        past_present_indices = end_seq - 1
+
         all_input_ids_tensor = torch.zeros(
             (total_batch_size, max_length), dtype=torch.int64, device=device
         )
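With the token dimension moved to dim 0, the per-batch pasts concatenate directly, and the most recently written slot of each sequence is `end_seq - 1`. A sketch, assuming a simplified per-layer layout of `(tokens, 2, num_heads, head_size)` as the dim-0 concatenation suggests:

    import torch

    num_heads, head_size = 2, 4

    # Two hypothetical pre-allocated pasts, already padded to their max_tokens
    past_a = torch.zeros(20, 2, num_heads, head_size)
    past_b = torch.zeros(10, 2, num_heads, head_size)

    # Token dimension is dim 0, so batches concatenate without re-packing heads
    past_key_values = torch.cat([past_a, past_b], dim=0)  # shape (30, 2, 2, 4)

    # The last written token of each sequence sits at end_seq - 1
    end_seq = torch.tensor([5, 23])
    past_present_indices = end_seq - 1  # tensor([ 4, 22])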
@@ -463,11 +501,6 @@ class FlashCausalLMBatch(Batch):

             cumulative_batch_size += len(batch)

-        # Cat past
-        past_key_values = torch.cat(past_key_values, dim=1)
-        # Create final tensor on GPU
-        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
-
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype=dtype, device=device
         )
@@ -478,8 +511,13 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            cu_seqlens_q=cu_seqlens_q,
+            past_present_indices=past_present_indices,
+            start_seq=start_seq,
+            end_seq=end_seq,
+            start_seq_prefill=None,
+            end_seq_prefill=None,
+            start_seq_q=start_seq_q,
+            end_seq_q=end_seq_q,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -550,9 +588,12 @@ class FlashCausalLM(Model):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        cu_seqlens_q: Optional[torch.Tensor],
+        start_seq: torch.Tensor,
+        end_seq: torch.Tensor,
+        start_seq_q: Optional[torch.Tensor],
+        end_seq_q: Optional[torch.Tensor],
         max_s: int,
+        past_present_indices: torch.Tensor,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -561,9 +602,12 @@ class FlashCausalLM(Model):
         return self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            cu_seqlens_q=cu_seqlens_q,
+            start_seq=start_seq,
+            end_seq=end_seq,
+            start_seq_q=start_seq_q,
+            end_seq_q=end_seq_q,
             max_s=max_s,
+            past_present_indices=past_present_indices,
             past_key_values=past_key_values,
             pre_allocate_past_size=pre_allocate_past_size,
             lm_head_indices=lm_head_indices,
@@ -575,23 +619,27 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
         prefill_logprobs = batch.prefill_next_token_indices is not None
-        single_request = len(batch) == 1

-        if prefill and single_request:
+        if prefill:
             # Ask to pre-allocate kv to its max size
-            # == number of tokens + max_new_tokens
-            pre_allocate_past_size = (
-                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
-            )
+            # == Sum over batch size (number of tokens + max_new_tokens) - batch size
+            pre_allocate_past_size = batch.max_tokens
+            start_seq = batch.start_seq_prefill
+            end_seq = batch.end_seq_prefill
         else:
             pre_allocate_past_size = None
+            start_seq = batch.start_seq
+            end_seq = batch.end_seq

         out, present = self.forward(
             batch.input_ids,
             batch.position_ids,
-            batch.cu_seqlens,
-            batch.cu_seqlens_q,
+            start_seq,
+            end_seq,
+            batch.start_seq_q,
+            batch.end_seq_q,
             batch.max_seqlen,
+            batch.past_present_indices,
             batch.past_key_values,
             pre_allocate_past_size,
             batch.prefill_head_indices,
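The prefill branch now reserves the kv budget for the whole batch up front instead of only for single requests. A sketch of the size the new comment describes, with invented per-request numbers (the real value is accumulated as `cumulative_max_length` at batch-construction time):

    # Hypothetical requests: prompt lengths and their max_new_tokens budgets
    input_lengths = [5, 3]
    max_new_tokens = [16, 8]

    # "Sum over batch size (number of tokens + max_new_tokens) - batch size":
    # the final generated token of each request never needs a kv slot of its own.
    batch_size = len(input_lengths)
    pre_allocate_past_size = (
        sum(l + m for l, m in zip(input_lengths, max_new_tokens)) - batch_size
    )  # (5 + 16) + (3 + 8) - 2 == 30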
@@ -614,55 +662,19 @@ class FlashCausalLM(Model):
             # When batch == 1, we will just use the batch.input_ids values directly
             prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

-            # Create batch.cu_seqlens_q for decode
-            batch.cu_seqlens_q = torch.arange(
-                0, len(batch) + 1, device=self.device, dtype=torch.int32
+            # Create batch.start_seq_q and batch.end_seq_q for decode
+            batch.start_seq_q = torch.arange(
+                0, len(batch), device=self.device, dtype=torch.int32
             )
+            batch.end_seq_q = batch.start_seq_q + 1
             next_position_ids = batch.position_ids.new_empty(len(batch))
+            # We do not need start_seq_prefill and end_seq_prefill anymore
+            batch.start_seq_prefill = None
+            batch.end_seq_prefill = None
         else:
             prefill_logprobs = None
             next_position_ids = batch.position_ids

-        # Prepare past for next decode
-        if len(batch) > 1:
-            # Used to slice next batch past
-            past_indices = torch.empty(
-                present.shape[1], dtype=torch.int64, device=self.device
-            )
-            batch.past_key_values = present.new_empty(
-                (
-                    present.shape[0],
-                    present.shape[1] + len(batch.requests),
-                    *present.shape[2:],
-                )
-            )
-
-            # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
-            # and will run asynchronously while we do the next for loop
-            cumulative_length = 0
-            for i, input_length in enumerate(batch.input_lengths):
-                # Indexing metadata
-                start_index = cumulative_length
-                end_index = cumulative_length + input_length
-
-                # Indices to copy present at the correct place in past_key_values
-                torch.arange(
-                    start_index + i,
-                    end_index + i,
-                    dtype=torch.int64,
-                    device=self.device,
-                    out=past_indices[start_index:end_index],
-                )
-                cumulative_length += input_length
-
-            # Copy from present to past_key_values
-            batch.past_key_values[:, past_indices] = present
-
-        # Initialize past_key_values in prefill for len(batch) == 1
-        elif prefill:
-            # present is already pre-padded
-            batch.past_key_values = present

         # Cumulative length
         cumulative_length = 0
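The whole copy loop above could be deleted because the model now returns `present` already laid out inside the pre-allocated, padded past; the server simply keeps it (see `batch.past_key_values = present` in the last hunk). A toy sketch of the indexed write that replaces the per-sequence Python loop, with invented shapes:

    import torch

    # Pre-allocated past: 30 token slots, (k, v), 2 heads, head_size 4
    past = torch.zeros(30, 2, 2, 4)

    # One new (k, v) pair per sequence, written at each sequence's next slot
    present = torch.randn(2, 2, 2, 4)
    past_present_indices = torch.tensor([4, 22])

    # A single indexed assignment replaces the old per-sequence Python loop
    past[past_present_indices] = present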
@@ -685,6 +697,7 @@ class FlashCausalLM(Model):
             input_length,
             all_input_ids,
         ) in enumerate(iterator):
+            # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length

@@ -718,7 +731,8 @@ class FlashCausalLM(Model):
         # Set values in batch
         batch.input_ids = next_input_ids
         batch.position_ids = next_position_ids + 1
-        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
+        batch.past_present_indices = batch.end_seq
+        batch.end_seq = batch.end_seq + 1

         if prefill and prefill_logprobs:
             # Get prefill logprobs
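Advancing the bookkeeping after a decode step is now two tensor ops: the next forward pass writes at the current `end_seq`, and each sequence grows by one slot. Continuing the toy values from the sketch above:

    import torch

    end_seq = torch.tensor([5, 23])

    # The next (k, v) pair lands at the current end of each sequence...
    past_present_indices = end_seq
    # ...and every surviving sequence is now one token longer. Note that
    # `end_seq + 1` rebinds rather than mutates, so the alias above keeps
    # the pre-increment values.
    end_seq = end_seq + 1  # tensor([ 6, 24])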
@@ -843,6 +857,7 @@ class FlashCausalLM(Model):
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
         batch.max_seqlen = batch.max_seqlen + 1
+        batch.past_key_values = present

         # No need to return a batch if we know that all requests stopped
         return generations, batch if not stopped else None