fix(server): fix past key values logic (#216)

@njhill fyi
OlivierDehaene 2023-04-21 15:59:18 +02:00 committed by GitHub
parent 343437c7b5
commit db4cb5e4ed
7 changed files with 123 additions and 20 deletions
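In short: for single-request prefills, the past key/value buffer is now allocated up front at its eventual maximum size (prompt length plus max_new_tokens), and the unused tail is treated as padding that gets sliced off wherever only the real tokens matter. A minimal sketch of that pattern follows; the tensor layout (layers, tokens, 2, heads, head_size) matches the diffs below, but the helper and its names are illustrative, not the repository's exact code.

```python
import torch
from typing import Optional


def allocate_past(
    hidden_states: torch.Tensor,          # (num_prompt_tokens, hidden_size)
    num_layers: int,
    num_heads: int,
    head_size: int,
    pre_allocate_past_size: Optional[int] = None,
) -> torch.Tensor:
    """Illustrative: allocate the KV buffer during prefill.

    With pre_allocate_past_size set (prompt length + max_new_tokens), later
    decode steps can write into this buffer instead of growing it each token.
    """
    num_tokens = len(hidden_states)
    alloc = num_tokens if pre_allocate_past_size is None else pre_allocate_past_size

    # One buffer for all layers: (layers, tokens, 2 (key/value), heads, head_size)
    past_key_values = hidden_states.new_empty(
        (num_layers, alloc, 2, num_heads, head_size)
    )

    # During this prefill pass, attention must only ever see the real prompt
    # tokens, so each layer gets a view that excludes the pre-allocated tail.
    for i in range(num_layers):
        layer_past = past_key_values[i, :num_tokens]
        # ... the real code runs flash attention here and fills layer_past ...

    return past_key_values  # padded buffer, kept for the decode steps
```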


@@ -25,6 +25,7 @@ from torch.nn import functional as F
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional
# Flash attention imports
import rotary_emb
@@ -554,7 +555,8 @@ class FlashLlamaModel(torch.nn.Module):
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.embed_tokens(input_ids)
@@ -564,7 +566,9 @@ class FlashLlamaModel(torch.nn.Module):
past_key_values = hidden_states.new_empty(
(
len(self.layers),
len(hidden_states),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
self.num_heads,
self.head_size,
@@ -572,6 +576,7 @@ class FlashLlamaModel(torch.nn.Module):
)
layer_past_present_indices = None
cu_seqlens_q = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
@@ -579,6 +584,7 @@ class FlashLlamaModel(torch.nn.Module):
cu_seqlens_q = torch.arange(
cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
)
slice_past_index = None
# Get rotary cos and sin for this forward
# Avoid indexing in each layer
@@ -588,6 +594,13 @@ class FlashLlamaModel(torch.nn.Module):
residual = None
for i, layer in enumerate(self.layers):
# We added padding that we now need to slice off
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
@@ -595,7 +608,7 @@ class FlashLlamaModel(torch.nn.Module):
sin,
cu_seqlens,
max_s,
past_key_values[i],
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
)
@@ -638,10 +651,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states, present = self.model(
input_ids, position_ids, cu_seqlens, max_s, past_key_values
input_ids,
position_ids,
cu_seqlens,
max_s,
past_key_values,
pre_allocate_past_size,
)
logits = self.lm_head(hidden_states)
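The causal-LM wrapper simply forwards the new keyword, so a caller drives the two phases roughly as below. This is a schematic call sequence, not code from the repository: the variable names are placeholders, and it assumes the wrapper returns (logits, present) as in the hunk above.

```python
def prefill_then_decode(model, input_ids, position_ids, cu_seqlens, max_s,
                        input_length: int, max_new_tokens: int):
    # Prefill: request a KV buffer already sized for the whole generation.
    logits, present = model(
        input_ids,
        position_ids,
        cu_seqlens,
        max_s,
        past_key_values=None,
        pre_allocate_past_size=input_length + max_new_tokens,
    )

    # Decode: hand the (padded) buffer straight back; nothing new is allocated,
    # the model writes the next token's key/value pair into the existing buffer.
    logits, present = model(
        input_ids,          # placeholder: a real decode step passes the next token ids
        position_ids,       # and the updated positions / cu_seqlens / max_s
        cu_seqlens,
        max_s,
        past_key_values=present,
        pre_allocate_past_size=None,
    )
    return logits, present
```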


@@ -27,6 +27,7 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional
# Flash attention imports
import rotary_emb
@@ -618,6 +619,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
cu_seqlens,
max_s,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.embed_in(input_ids)
@@ -627,7 +629,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
past_key_values = hidden_states.new_empty(
(
len(self.layers),
len(hidden_states),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
self.num_heads,
self.head_size,
@@ -635,6 +639,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
)
layer_past_present_indices = None
cu_seqlens_q = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
@@ -642,6 +647,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
cu_seqlens_q = torch.arange(
cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
)
slice_past_index = None
# Get rotary cos and sin for this forward
# Avoid indexing in each layer
@@ -651,6 +657,13 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
residual = None
for i, layer in enumerate(self.layers):
# We added padding that we now need to slice off
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
@@ -658,7 +671,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
sin,
cu_seqlens,
max_s,
past_key_values[i],
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
)
@@ -714,10 +727,16 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states, present = self.gpt_neox(
input_ids, position_ids, cu_seqlens, max_s, past_key_values
input_ids,
position_ids,
cu_seqlens,
max_s,
past_key_values,
pre_allocate_past_size,
)
logits = self.embed_out(hidden_states)


@@ -5,6 +5,7 @@ import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional
# Flash attention imports
import flash_attn_cuda
@@ -484,7 +485,8 @@ class FlashSantacoderModel(nn.Module):
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.tp_embeddings:
@@ -496,7 +498,9 @@ class FlashSantacoderModel(nn.Module):
past_key_values = hidden_states.new_empty(
(
len(self.h),
len(hidden_states),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
2,
1,
self.head_size,
@@ -504,6 +508,7 @@ class FlashSantacoderModel(nn.Module):
)
layer_past_present_indices = None
cu_seqlens_q = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
@@ -511,15 +516,23 @@ class FlashSantacoderModel(nn.Module):
cu_seqlens_q = torch.arange(
cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
)
slice_past_index = None
residual = None
for i, layer in enumerate(self.h):
# We added padding that we now need to slice off
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cu_seqlens,
max_s,
past_key_values[i],
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
)
@@ -554,10 +567,16 @@ class FlashSantacoderForCausalLM(nn.Module):
position_ids,
cu_seqlens,
max_s,
past_key_values=None,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states, present = self.transformer(
input_ids, position_ids, cu_seqlens, max_s, past_key_values
input_ids,
position_ids,
cu_seqlens,
max_s,
past_key_values,
pre_allocate_past_size,
)
logits = self.lm_head(hidden_states)
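One difference worth noting in the Santacoder hunks above: the head dimension of the buffer is 1 rather than self.num_heads, because Santacoder uses multi-query attention and all query heads share a single key/value head. A tiny shape comparison with made-up sizes:

```python
import torch

num_layers, tokens, head_size = 24, 7, 128

# Llama / GPT-NeoX style buffer: one K/V pair per attention head.
mha_past = torch.empty(num_layers, tokens, 2, 16, head_size)

# Santacoder (multi-query attention): a single shared K/V head.
mqa_past = torch.empty(num_layers, tokens, 2, 1, head_size)

print(mha_past.shape)  # torch.Size([24, 7, 2, 16, 128])
print(mqa_past.shape)  # torch.Size([24, 7, 2, 1, 128])
```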


@@ -142,6 +142,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
past_pad=None,
)
@tracer.start_as_current_span("filter")
@@ -188,8 +189,10 @@ class FlashCausalLMBatch(Batch):
cu_seqlens.append(cumulative_length + request_input_length)
max_seqlen = max(max_seqlen, request_input_length)
if not single_request:
# True index for past
past_key_values.append(self.past_key_values[2 * idx])
past_key_values.append(self.past_key_values[1])
# Add one padding
past_key_values.append(self.past_pad)
all_input_ids.append(self.all_input_ids[idx])
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
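The 2 * idx lookup above exists because, for batches with more than one request, past_key_values is kept as a flat list that interleaves each request's past with a one-token padding entry. A trivial sketch of that index mapping (strings stand in for the per-request tensors):

```python
# Illustrative only: each string stands in for a (layers, tokens, 2, heads, head_size) tensor.
past_key_values = ["past_0", "PAD", "past_1", "PAD", "past_2", "PAD"]


def past_for_request(idx: int) -> str:
    # Requests are interleaved with padding entries, so request idx lives at slot 2 * idx.
    return past_key_values[2 * idx]


assert past_for_request(1) == "past_1"
```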
@@ -207,7 +210,17 @@ class FlashCausalLMBatch(Batch):
# Preallocate tensor for bs = 1 case
past_key_values = torch.nn.functional.pad(
self.past_key_values[0],
(0, 0, 0, 0, 0, 0, 0, stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens)
(
0,
0,
0,
0,
0,
0,
0,
stopping_criterias[0].max_new_tokens
- stopping_criterias[0].current_tokens,
),
)
return FlashCausalLMBatch(
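For readers unpacking the long tuple above: torch.nn.functional.pad takes (left, right) pairs applied from the last dimension backwards, so on a 5-D past of shape (layers, tokens, 2, heads, head_size) the first three pairs leave head_size, heads and the key/value axis untouched, and the final pair grows only the token axis by the number of tokens still to be generated. A quick shape check with made-up sizes:

```python
import torch
import torch.nn.functional as F

past = torch.zeros(24, 7, 2, 16, 128)   # (layers, tokens, 2, heads, head_size)
remaining = 13                           # e.g. max_new_tokens - current_tokens

padded = F.pad(past, (0, 0, 0, 0, 0, 0, 0, remaining))
print(padded.shape)  # torch.Size([24, 20, 2, 16, 128]) -- only the token dim grew
```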
@@ -270,10 +283,16 @@ class FlashCausalLMBatch(Batch):
# Add cumulative lengths of all previous inputs
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
max_seqlen = max(max_seqlen, batch.max_seqlen)
if len(batch) != 1:
past_key_values.extend(batch.past_key_values)
else:
past_key_values.append(batch.past_key_values[:, :batch.input_lengths[0]])
# past was pre-allocated for this batch
# We need to slice to remove the padding
past_key_values.append(
batch.past_key_values[:, : batch.input_lengths[0]]
)
# Add one padding
past_key_values.append(batch.past_pad)
all_input_ids.extend(batch.all_input_ids)
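Concatenation has to undo the single-request optimisation: a batch of size one arrives with the pre-allocated, padded buffer, so its past is sliced back to the real token count and followed by the shared one-token padding entry, while larger batches already carry the interleaved [past, pad, past, pad, ...] list and are extended as-is. A rough sketch of that merge, assuming the batch attributes shown in the hunk above:

```python
import torch
from typing import List


def merge_past(batches) -> List[torch.Tensor]:
    """Illustrative: rebuild the interleaved [past, pad, past, pad, ...] list."""
    merged: List[torch.Tensor] = []
    for batch in batches:
        if len(batch) == 1:
            # Pre-allocated buffer: drop the tail reserved for future tokens.
            merged.append(batch.past_key_values[:, : batch.input_lengths[0]])
            merged.append(batch.past_pad)  # one-token padding separator
        else:
            merged.extend(batch.past_key_values)  # already interleaved with padding
    return merged
```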
@@ -366,6 +385,7 @@ class FlashCausalLM(Model):
cu_seqlens: torch.Tensor,
max_s: int,
past_key_values: Optional = None,
pre_allocate_past_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Model Forward
return self.model.forward(
@@ -374,6 +394,7 @@ class FlashCausalLM(Model):
cu_seqlens=cu_seqlens,
max_s=max_s,
past_key_values=past_key_values,
pre_allocate_past_size=pre_allocate_past_size,
)
@tracer.start_as_current_span("generate_token")
@@ -382,7 +403,9 @@ class FlashCausalLM(Model):
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
# Shortcut when batch_size == 1
if len(batch) == 1:
# No need to slice this down
input_ids = batch.input_ids[0].view(-1)
# Slice to remove extra padding
# past_key_values = batch.past_key_values[:, :batch.input_lengths[0]] if batch.past_key_values is not None else None
past_key_values = batch.past_key_values
else:
# Concatenate tensors
@@ -393,6 +416,16 @@ class FlashCausalLM(Model):
else None
)
# if prefill and bs == 1
if past_key_values is None and len(batch) == 1:
# Ask to pre-allocate kv to its max size
# == number of tokens + max_new_tokens
pre_allocate_past_size = (
batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
)
else:
pre_allocate_past_size = None
# Concatenate when prefill, torch.tensor when decode
position_ids = (
torch.tensor(batch.position_ids, device=self.device)
@@ -409,21 +442,28 @@ class FlashCausalLM(Model):
cu_seqlens,
batch.max_seqlen,
past_key_values,
pre_allocate_past_size,
)
# Initialize past_key_values in prefill
if batch.past_key_values is None:
# Initialize past padding tensor
if self.past_pad is None:
self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
self.past_pad = present.new_zeros(
present.shape[0], 1, *present.shape[2:]
)
# Set in batch in case it needs to be used later in concatenate()
batch.past_pad = self.past_pad
if len(batch) == 1:
# Preallocate tensor for bs = 1 case
batch.past_key_values = torch.nn.functional.pad(
present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
present,
(0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
)
else:
# Add padding after each sequence
# This will have the correct shape after the final past_key_values concatenation before the model
# forward
batch.past_key_values = [None, self.past_pad] * len(batch)
# Cumulative length
@@ -555,6 +595,7 @@ class FlashCausalLM(Model):
batch.all_input_ids_tensor[i] = all_input_ids_tensor
batch.max_seqlen = max(batch.max_seqlen, new_input_length)
if len(batch) != 1:
# Add each sequence before its padding
batch.past_key_values[i * 2] = present[:, start_index:end_index]
# Cumulative sum
batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
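After the forward pass in the multi-request case, batch.past_key_values is the placeholder list [None, pad] * len(batch): even slots receive each sequence's slice of present, odd slots keep the shared one-token padding, and concatenating the list along the token axis before the next forward yields each sequence's cache followed by exactly one padding slot. A compact sketch of that bookkeeping with simplified shapes:

```python
import torch

num_layers, heads, head_size = 2, 4, 8
seq_lens = [3, 5]                                    # per-sequence lengths after this step
present = torch.randn(num_layers, sum(seq_lens), 2, heads, head_size)
past_pad = present.new_zeros(num_layers, 1, 2, heads, head_size)

# Placeholder slots: [None, pad, None, pad] -- one (past, pad) pair per sequence.
past_key_values = [None, past_pad] * len(seq_lens)

start = 0
for i, length in enumerate(seq_lens):
    end = start + length
    past_key_values[i * 2] = present[:, start:end]   # each sequence goes before its padding
    start = end

# Before the next forward, the list is concatenated along the token dimension.
stacked = torch.cat(past_key_values, dim=1)
print(stacked.shape)  # torch.Size([2, 10, 2, 4, 8]): 3 + 1 pad + 5 + 1 pad tokens
```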


@@ -29,6 +29,7 @@ tracer = trace.get_tracer(__name__)
class FlashLlama(FlashCausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
self.past_pad = None
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -146,6 +147,7 @@ class FlashLlamaSharded(FlashLlama):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.past_pad = None
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():


@@ -33,6 +33,7 @@ class FlashNeoXSharded(FlashNeoX):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.past_pad = None
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():


@@ -28,6 +28,7 @@ tracer = trace.get_tracer(__name__)
class FlashSantacoder(FlashCausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
self.past_pad = None
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -172,6 +173,7 @@ class FlashSantacoderSharded(FlashSantacoder):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.past_pad = None
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():