diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 7184388..ee1bd01 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -25,6 +25,7 @@ from torch.nn import functional as F
 
 from torch import nn
 from transformers.activations import ACT2FN
+from typing import Optional
 
 # Flash attention imports
 import rotary_emb
@@ -554,7 +555,8 @@ class FlashLlamaModel(torch.nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_tokens(input_ids)
 
@@ -564,7 +566,9 @@ class FlashLlamaModel(torch.nn.Module):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.layers),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
@@ -572,6 +576,7 @@ class FlashLlamaModel(torch.nn.Module):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
@@ -579,6 +584,7 @@ class FlashLlamaModel(torch.nn.Module):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )
+            slice_past_index = None
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -588,6 +594,13 @@ class FlashLlamaModel(torch.nn.Module):
 
         residual = None
         for i, layer in enumerate(self.layers):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -595,7 +608,7 @@ class FlashLlamaModel(torch.nn.Module):
                 sin,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )
@@ -638,10 +651,16 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.model(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
         )
         logits = self.lm_head(hidden_states)
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 8de582e..6054584 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -27,6 +27,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
+from typing import Optional
 
 # Flash attention imports
 import rotary_emb
@@ -618,6 +619,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         cu_seqlens,
         max_s,
         past_key_values=None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.embed_in(input_ids)
 
@@ -627,7 +629,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.layers),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     self.num_heads,
                     self.head_size,
@@ -635,6 +639,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
@@ -642,6 +647,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
            )
+            slice_past_index = None
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -651,6 +657,13 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 
         residual = None
         for i, layer in enumerate(self.layers):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -658,7 +671,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 sin,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )
@@ -714,10 +727,16 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.gpt_neox(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
         )
         logits = self.embed_out(hidden_states)
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 793b3d1..736f896 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 
 from torch import nn
 from transformers.activations import ACT2FN
+from typing import Optional
 
 # Flash attention imports
 import flash_attn_cuda
@@ -484,7 +485,8 @@ class FlashSantacoderModel(nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
         if self.tp_embeddings:
@@ -496,7 +498,9 @@ class FlashSantacoderModel(nn.Module):
             past_key_values = hidden_states.new_empty(
                 (
                     len(self.h),
-                    len(hidden_states),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
                     2,
                     1,
                     self.head_size,
@@ -504,6 +508,7 @@ class FlashSantacoderModel(nn.Module):
             )
             layer_past_present_indices = None
             cu_seqlens_q = None
+            slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
@@ -511,15 +516,23 @@ class FlashSantacoderModel(nn.Module):
             cu_seqlens_q = torch.arange(
                 cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
             )
+            slice_past_index = None
 
         residual = None
         for i, layer in enumerate(self.h):
+            # We added padding that now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
                 cu_seqlens,
                 max_s,
-                past_key_values[i],
+                layer_past_key_values,
                 layer_past_present_indices,
                 cu_seqlens_q,
             )
@@ -554,10 +567,16 @@ class FlashSantacoderForCausalLM(nn.Module):
         position_ids,
         cu_seqlens,
         max_s,
-        past_key_values=None,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states, present = self.transformer(
-            input_ids, position_ids, cu_seqlens, max_s, past_key_values
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
         )
         logits = self.lm_head(hidden_states)
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 0e2fbaa..7be1312 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -142,6 +142,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
+            past_pad=None,
         )
 
     @tracer.start_as_current_span("filter")
@@ -188,8 +189,10 @@ class FlashCausalLMBatch(Batch):
             cu_seqlens.append(cumulative_length + request_input_length)
             max_seqlen = max(max_seqlen, request_input_length)
             if not single_request:
+                # True index for past
                 past_key_values.append(self.past_key_values[2 * idx])
-                past_key_values.append(self.past_key_values[1])
+                # Add one padding
+                past_key_values.append(self.past_pad)
 
             all_input_ids.append(self.all_input_ids[idx])
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
@@ -207,7 +210,17 @@ class FlashCausalLMBatch(Batch):
             # Preallocate tensor for bs = 1 case
             past_key_values = torch.nn.functional.pad(
                 self.past_key_values[0],
-                (0, 0, 0, 0, 0, 0, 0, stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens)
+                (
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    stopping_criterias[0].max_new_tokens
+                    - stopping_criterias[0].current_tokens,
+                ),
             )
 
         return FlashCausalLMBatch(
@@ -270,10 +283,16 @@ class FlashCausalLMBatch(Batch):
             # Add cumulative lengths of all previous inputs
             cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
             max_seqlen = max(max_seqlen, batch.max_seqlen)
+
             if len(batch) != 1:
                 past_key_values.extend(batch.past_key_values)
             else:
-                past_key_values.append(batch.past_key_values[:, :batch.input_lengths[0]])
+                # past was pre-allocated for this batch
+                # We need to slice to remove the padding
+                past_key_values.append(
+                    batch.past_key_values[:, : batch.input_lengths[0]]
+                )
+                # Add one padding
                 past_key_values.append(batch.past_pad)
 
             all_input_ids.extend(batch.all_input_ids)
@@ -366,6 +385,7 @@ class FlashCausalLM(Model):
         cu_seqlens: torch.Tensor,
         max_s: int,
         past_key_values: Optional = None,
+        pre_allocate_past_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
@@ -374,6 +394,7 @@ class FlashCausalLM(Model):
             cu_seqlens=cu_seqlens,
             max_s=max_s,
             past_key_values=past_key_values,
+            pre_allocate_past_size=pre_allocate_past_size,
         )
 
     @tracer.start_as_current_span("generate_token")
@@ -382,7 +403,9 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         # Shortcut when batch_size == 1
         if len(batch) == 1:
-            # No need to slice this down
+            input_ids = batch.input_ids[0].view(-1)
+            # Slice to remove extra padding
+            # past_key_values = batch.past_key_values[:, :batch.input_lengths[0]] if batch.past_key_values is not None else None
             past_key_values = batch.past_key_values
         else:
             # Concatenate tensors
@@ -393,6 +416,16 @@ class FlashCausalLM(Model):
                 else None
             )
 
+        # if prefill and bs == 1
+        if past_key_values is None and len(batch) == 1:
+            # Ask to pre-allocate kv to its max size
+            # == number of tokens + max_new_tokens
+            pre_allocate_past_size = (
+                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
+            )
+        else:
+            pre_allocate_past_size = None
+
         # Concatenate when prefill, torch.tensor when decode
         position_ids = (
             torch.tensor(batch.position_ids, device=self.device)
@@ -409,21 +442,28 @@ class FlashCausalLM(Model):
             cu_seqlens,
             batch.max_seqlen,
             past_key_values,
+            pre_allocate_past_size,
         )
 
         # Initialize past_key_values in prefill
         if batch.past_key_values is None:
             # Initialize past padding tensor
             if self.past_pad is None:
-                self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
+                self.past_pad = present.new_zeros(
+                    present.shape[0], 1, *present.shape[2:]
+                )
             # Set in batch in case it needs to be used later in concatenate()
             batch.past_pad = self.past_pad
             if len(batch) == 1:
                 # Preallocate tensor for bs = 1 case
                 batch.past_key_values = torch.nn.functional.pad(
-                    present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
+                    present,
+                    (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens),
                 )
             else:
+                # Add padding after each sequence
+                # This will have the correct shape after the final past_key_values concatenation before the model
+                # forward
                 batch.past_key_values = [None, self.past_pad] * len(batch)
 
         # Cumulative length
@@ -555,6 +595,7 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor[i] = all_input_ids_tensor
             batch.max_seqlen = max(batch.max_seqlen, new_input_length)
             if len(batch) != 1:
+                # Add each sequence before its padding
                 batch.past_key_values[i * 2] = present[:, start_index:end_index]
             # Cumulative sum
             batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 764de2a..e640113 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -29,6 +29,7 @@ tracer = trace.get_tracer(__name__)
 
 class FlashLlama(FlashCausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -146,6 +147,7 @@ class FlashLlamaSharded(FlashLlama):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 259fc20..eae584a 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -33,6 +33,7 @@ class FlashNeoXSharded(FlashNeoX):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 7dcd8b0..aa1bdfb 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -28,6 +28,7 @@ tracer = trace.get_tracer(__name__)
 
 class FlashSantacoder(FlashCausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -172,6 +173,7 @@ class FlashSantacoderSharded(FlashSantacoder):
     def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
+        self.past_pad = None
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
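
For readers skimming the patch: when a batch holds a single request, the prefill pass now allocates the past key/value tensor once with enough token slots for the whole generation (`input_length + max_new_tokens`), and each layer slices off the still-unused padding before attending. The sketch below is a minimal, repository-independent illustration of that allocate-then-slice pattern; the shape layout `(num_layers, num_tokens, 2, num_heads, head_size)` mirrors the `new_empty` calls in the modeling files above, but the helper name `init_past` and the toy sizes are assumptions for illustration only, not code from this PR.

```python
import torch
from typing import Optional


def init_past(
    hidden_states: torch.Tensor,
    num_layers: int,
    num_heads: int,
    head_size: int,
    pre_allocate_past_size: Optional[int] = None,
) -> torch.Tensor:
    # Mirrors the `hidden_states.new_empty(...)` calls above: the token
    # dimension is either the number of prompt tokens, or the pre-allocated
    # size that already reserves room for every future decode step.
    num_tokens = len(hidden_states)
    size = num_tokens if pre_allocate_past_size is None else pre_allocate_past_size
    return hidden_states.new_empty((num_layers, size, 2, num_heads, head_size))


# Toy prefill: 5 prompt tokens, room reserved for 20 generated tokens.
hidden_states = torch.zeros(5, 64)
past = init_past(
    hidden_states,
    num_layers=2,
    num_heads=4,
    head_size=16,
    pre_allocate_past_size=5 + 20,  # input_length + max_new_tokens
)

# Before each layer call, only the slots holding real tokens are passed,
# analogous to `past_key_values[i, :slice_past_index]` in the diff.
slice_past_index = len(hidden_states)
layer_past = past[0, :slice_past_index]

print(past.shape)        # torch.Size([2, 25, 2, 4, 16])
print(layer_past.shape)  # torch.Size([5, 2, 4, 16])
```

On the server side (`flash_causal_lm.py`), the same budget is computed as `batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens` and handed down as `pre_allocate_past_size`, so decode steps for a single-sequence batch write into the reserved slots instead of re-concatenating the past at every generated token.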