diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 382f917d..2843f273 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -38,7 +38,7 @@ class FlashCausalLMBatch(Batch):
     # cumulative sequence lengths
     cu_seqlens: List[int]
     max_seqlen: int
-    past_key_values: Optional[List[torch.Tensor]]
+    past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]
 
     # All tokens
     all_input_ids: List[List[int]]
@@ -53,6 +53,9 @@ class FlashCausalLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
 
+    # Constant shared tensor, ref here just so that it's accessible in concatenate()
+    past_pad: Optional[torch.Tensor]
+
     def to_pb(self) -> generate_pb2.Batch:
         return generate_pb2.Batch(
             id=self.batch_id, requests=self.requests, size=len(self)
@@ -149,6 +152,8 @@
         if len(requests) == len(self):
             return self
 
+        single_request = len(requests) == 1
+
         # Cumulative length
         cumulative_length = 0
 
@@ -182,7 +187,9 @@
             position_ids.append(self.position_ids[idx])
             cu_seqlens.append(cumulative_length + request_input_length)
             max_seqlen = max(max_seqlen, request_input_length)
-            past_key_values.append(self.past_key_values[idx])
+            if not single_request:
+                past_key_values.append(self.past_key_values[2 * idx])
+                past_key_values.append(self.past_key_values[1])
             all_input_ids.append(self.all_input_ids[idx])
 
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
@@ -196,6 +203,13 @@
 
             cumulative_length += request_input_length
 
+        if single_request:
+            # Preallocate tensor for bs = 1 case
+            past_key_values = torch.nn.functional.pad(
+                self.past_key_values[0],
+                (0, 0, 0, 0, 0, 0, 0, stopping_criterias[0].max_new_tokens - stopping_criterias[0].current_tokens)
+            )
+
         return FlashCausalLMBatch(
             batch_id=self.batch_id,
             requests=requests,
@@ -256,7 +270,11 @@
             # Add cumulative lengths of all previous inputs
             cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
             max_seqlen = max(max_seqlen, batch.max_seqlen)
-            past_key_values.extend(batch.past_key_values)
+            if len(batch) != 1:
+                past_key_values.extend(batch.past_key_values)
+            else:
+                past_key_values.append(batch.past_key_values[:, :batch.input_lengths[0]])
+                past_key_values.append(batch.past_pad)
             all_input_ids.extend(batch.all_input_ids)
 
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
@@ -303,6 +321,7 @@ class FlashCausalLM(Model):
         quantize: bool = False,
         decode_buffer: int = 3,
     ):
+        self.past_pad = None
         if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -359,10 +378,8 @@
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         # Shortcut when batch_size == 1
         if len(batch) == 1:
-            input_ids = batch.input_ids[0].view(-1)
-            past_key_values = (
-                batch.past_key_values[0] if batch.past_key_values is not None else None
-            )
+            # No need to slice this down
+            past_key_values = batch.past_key_values
         else:
             # Concatenate tensors
             input_ids = torch.cat(batch.input_ids).view(-1)
@@ -392,7 +409,18 @@
         # Initialize past_key_values in prefill
         if batch.past_key_values is None:
-            batch.past_key_values = [None] * len(batch)
+            # Initialize past padding tensor
+            if self.past_pad is None:
+                self.past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])
+            # Set in batch in case it needs to be used later in concatenate()
+            batch.past_pad = self.past_pad
+            if len(batch) == 1:
+                # Preallocate tensor for bs = 1 case
+                batch.past_key_values = torch.nn.functional.pad(
+                    present, (0, 0, 0, 0, 0, 0, 0, batch.stopping_criterias[0].max_new_tokens)
+                )
+            else:
+                batch.past_key_values = [None, self.past_pad] * len(batch)
 
         # Cumulative length
         cumulative_length = 0
 
@@ -477,21 +505,10 @@
                 generated_text = GeneratedText(
                     output_text, stopping_criteria.current_tokens, reason, seed
                 )
-
-                # CAUTION: generation will be stopped so no need to pad
-                # This will make the next forward crash if the request does not get filtered
-                new_input_length = input_length
-                past = present[:, start_index:end_index]
             else:
                 stopped = False
                 generated_text = None
-
-                # Pad present for next iter attention
-                new_input_length = input_length + 1
-                past = torch.nn.functional.pad(
-                    present[:, start_index:end_index], (0, 0, 0, 0, 0, 0, 0, 1)
-                )
 
             # Prefill
             if prefill:
                 # Remove generated token to only have prefill and add nan for first prompt token
@@ -522,6 +539,7 @@
 
             generations.append(generation)
             cumulative_length += input_length
+            new_input_length = input_length + 1
 
             # Update values
             batch.input_ids[i] = next_token_id
@@ -532,7 +550,8 @@
             batch.all_input_ids[i] = all_input_ids
             batch.all_input_ids_tensor[i] = all_input_ids_tensor
             batch.max_seqlen = max(batch.max_seqlen, new_input_length)
-            batch.past_key_values[i] = past
+            if len(batch) != 1:
+                batch.past_key_values[i * 2] = present[:, start_index:end_index]
 
             # Cumulative sum
             batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
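
Note (not part of the patch): a minimal sketch of the preallocation scheme the diff introduces. It assumes a KV layout of [num_layers, tokens, 2, num_heads, head_dim], i.e. dim 1 is the token dimension, as the slicing `present[:, start_index:end_index]` and the 8-value pad tuple suggest; the exact layout and the dummy sizes below are assumptions. Instead of re-padding the per-request past by one slot every decode step, the prefill `present` is padded once with room for `max_new_tokens`, and a shared one-slot `past_pad` tensor is reused when batches are filtered or concatenated.

    # Sketch only; names and shapes are illustrative, not the library's API.
    import torch

    num_layers, prompt_tokens, num_heads, head_dim = 2, 5, 4, 8
    max_new_tokens = 3

    # Stand-in for `present` returned by the prefill forward pass
    present = torch.randn(num_layers, prompt_tokens, 2, num_heads, head_dim)

    # Preallocate space for every future token once.
    # An 8-value pad tuple pads dims (-1, -2, -3, -4); the last pair targets dim 1.
    past_key_values = torch.nn.functional.pad(
        present, (0, 0, 0, 0, 0, 0, 0, max_new_tokens)
    )
    assert past_key_values.shape[1] == prompt_tokens + max_new_tokens

    # Shared single-token pad slot, analogous to `past_pad` in the patch.
    past_pad = present.new_zeros(present.shape[0], 1, *present.shape[2:])

    # When this bs=1 request later joins a larger batch, concatenate() slices the
    # preallocated tensor back down to the tokens actually written and re-appends
    # a one-slot pad, giving the interleaved [past, pad, past, pad, ...] list layout.
    input_length = prompt_tokens
    per_request_past = [past_key_values[:, :input_length], past_pad]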