From 83e442ca9a764ce67aa0a58906b8a2c92adda162 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 20 Jun 2023 18:29:55 +0200
Subject: [PATCH] feat(server): use encoding to get prefill tokens

---
 router/src/infer.rs                           | 26 ++++++++++++-----
 .../models/causal_lm.py                       | 21 ++++++++++----
 .../models/flash_causal_lm.py                 | 28 +++++++++++++------
 server/text_generation_server/models/types.py |  2 ++
 4 files changed, 56 insertions(+), 21 deletions(-)

diff --git a/router/src/infer.rs b/router/src/infer.rs
index 00fa2818..048acfb2 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -150,15 +150,27 @@ impl Infer {
             match response? {
                 // Add prefill tokens
                 InferStreamResponse::Prefill(tokens) => {
+                    let tokens_length = tokens.ids.len();
+
+                    // Validation
+                    // logprobs, ids and texts must have the same lengths
+                    if tokens.logprobs.len() != tokens_length || tokens.texts.len() != tokens_length {
+                        return Err(InferError::GenerationError(format!("Prefill tokens do not have the correct lengths")))
+                    }
+
+                    result_prefill = Vec::with_capacity(tokens_length);
+
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
-                    result_prefill = tokens
-                        .ids
-                        .into_iter()
-                        .zip(tokens.logprobs.into_iter())
-                        .zip(tokens.texts.into_iter())
-                        .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
-                        .collect();
+                    for ((id, logprob), text) in tokens.ids.into_iter().zip(
+                        tokens.logprobs.into_iter()
+                    ).zip(tokens.texts.into_iter()) {
+                        result_prefill.push(PrefillToken{
+                            id,
+                            text,
+                            logprob,
+                        });
+                    }
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index ba0853f5..151e0b63 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1,8 +1,8 @@
 import torch
-import inspect
 
 from dataclasses import dataclass
 from opentelemetry import trace
+from tokenizers import Encoding
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
@@ -53,6 +53,9 @@ class CausalLMBatch(Batch):
     # Past metadata
     keys_head_dim_last: bool = True
 
+    # Input encodings
+    encodings: Optional[List[Encoding]] = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -141,6 +144,7 @@ class CausalLMBatch(Batch):
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
+            encodings=tokenized_inputs.encodings,
         )
 
     @tracer.start_as_current_span("filter")
@@ -625,11 +629,16 @@ class CausalLM(Model):
                     -new_input_length:-1
                 ].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
+                if batch.encodings is not None:
+                    prefill_texts = batch.encodings[i].tokens[
+                        -new_input_length - 1 :
+                    ]
+                else:
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
                 prefill_tokens = PrefillTokens(
                     prefill_token_ids, prefill_logprobs, prefill_texts
                 )
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index ecea998e..f9963e9d 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -5,6 +5,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
+from tokenizers import Encoding
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Union, Dict
 
@@ -72,6 +73,9 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
+    # Input encodings
+    encodings: Optional[List[Encoding]] = None
+
    def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -94,9 +98,9 @@ class FlashCausalLMBatch(Batch):
             batch_inputs.append(r.inputs)
             max_truncation = max(max_truncation, r.truncate)
 
-        batch_tokenized_inputs = tokenizer(
+        batch_encoding = tokenizer(
             batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
+        )
 
         position_ids = []
         past_present_indices = []
@@ -130,7 +134,7 @@ class FlashCausalLMBatch(Batch):
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
-            zip(pb.requests, batch_tokenized_inputs)
+            zip(pb.requests, batch_encoding["input_ids"])
         ):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
 
@@ -282,6 +286,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             max_tokens=cumulative_max_length,
+            encodings=batch_encoding.encodings,
         )
 
     @tracer.start_as_current_span("filter")
@@ -822,11 +827,18 @@ class FlashCausalLM(Model):
                     out_start_index : out_end_index - 1
                 ]
                 prefill_token_ids = all_input_ids[:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
+
+                if batch.encodings is not None:
+                    prefill_texts = batch.encodings[i].tokens[
+                        -len(all_input_ids) - 1 :
+                    ]
+                else:
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+
                 prefill_tokens = PrefillTokens(
                     prefill_token_ids, request_prefill_logprobs, prefill_texts
                 )
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 28ca8147..df5b576d 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -61,6 +61,8 @@ class PrefillTokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
+    token_is_special: Optional[List[bool]] = None
+    offsets: Optional[List[bool]] = None
 
     def to_pb(self) -> generate_pb2.PrefillTokens:
         return generate_pb2.PrefillTokens(
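Note on the approach (illustration only, not part of the patch): with a fast, Rust-backed tokenizer, the BatchEncoding returned by tokenizer(...) already carries per-token strings, so the server can read them from encodings[i].tokens instead of calling batch_decode on every prefill token id. The sketch below shows the two paths side by side; the "gpt2" checkpoint and the example input are placeholders, and slow tokenizers would return encodings=None.

# Illustrative sketch, assuming a fast tokenizer; "gpt2" is only an example checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

batch_inputs = ["Deep learning is"]
batch_encoding = tokenizer(batch_inputs, truncation=True, max_length=1024)

# Fast tokenizers expose one tokenizers.Encoding per input; slow tokenizers leave this as None.
encoding = batch_encoding.encodings[0]

# New path: token strings taken directly from the encoding.
prefill_texts = encoding.tokens

# Old path: decode each prefill token id on its own.
prefill_texts_decoded = tokenizer.batch_decode(
    [[token_id] for token_id in encoding.ids],
    clean_up_tokenization_spaces=False,
    skip_special_tokens=False,
)

# The two forms can differ in surface form (e.g. "Ġis" vs. " is"), since
# Encoding.tokens returns raw sub-word tokens rather than decoded text.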