feat(server): use encoding to get prefill tokens

commit 83e442ca9a (parent 53aa9194c8)
Author: OlivierDehaene
Date: 2023-06-20 18:29:55 +02:00
4 changed files with 56 additions and 21 deletions


@@ -150,15 +150,27 @@ impl Infer {
             match response? {
                 // Add prefill tokens
                 InferStreamResponse::Prefill(tokens) => {
+                    let tokens_length = tokens.ids.len();
+
+                    // Validation
+                    // logprobs, ids and texts must have the same lengths
+                    if tokens.logprobs.len() != tokens_length || tokens.texts.len() != tokens_length {
+                        return Err(InferError::GenerationError(format!("Prefill tokens do not have the correct lengths")))
+                    }
+
+                    result_prefill = Vec::with_capacity(tokens_length);
+
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
-                    result_prefill = tokens
-                        .ids
-                        .into_iter()
-                        .zip(tokens.logprobs.into_iter())
-                        .zip(tokens.texts.into_iter())
-                        .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
-                        .collect();
+                    for ((id, logprob), text) in tokens.ids.into_iter().zip(
+                        tokens.logprobs.into_iter()
+                    ).zip(tokens.texts.into_iter()) {
+                        result_prefill.push(PrefillToken{
+                            id,
+                            text,
+                            logprob,
+                        });
+                    }
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
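A note on the router-side guard added above: both Rust's and Python's zip stop at the shorter of the two sequences without raising, so mismatched ids, logprobs and texts would silently produce a truncated prefill list instead of an error. A minimal Python sketch of the same check (illustrative only; the PrefillToken dataclass here is a stand-in for the Rust struct):

from dataclasses import dataclass
from typing import List


@dataclass
class PrefillToken:
    id: int
    text: str
    logprob: float


def build_prefill(ids: List[int], logprobs: List[float], texts: List[str]) -> List[PrefillToken]:
    # zip() would silently drop trailing items if the lists diverged,
    # so validate the lengths up front, as the router now does.
    if not (len(ids) == len(logprobs) == len(texts)):
        raise ValueError("Prefill tokens do not have the correct lengths")
    return [
        PrefillToken(id=i, text=t, logprob=lp)
        for i, lp, t in zip(ids, logprobs, texts)
    ]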


@@ -1,8 +1,8 @@
 import torch
-import inspect
 
 from dataclasses import dataclass
 from opentelemetry import trace
+from tokenizers import Encoding
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
 
@@ -53,6 +53,9 @@ class CausalLMBatch(Batch):
     # Past metadata
     keys_head_dim_last: bool = True
 
+    # Input encodings
+    encodings: Optional[List[Encoding]] = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -141,6 +144,7 @@ class CausalLMBatch(Batch):
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
+            encodings=tokenized_inputs.encodings,
         )
 
     @tracer.start_as_current_span("filter")
@@ -625,11 +629,16 @@ class CausalLM(Model):
                     -new_input_length:-1
                 ].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
+                if batch.encodings is not None:
+                    prefill_texts = batch.encodings[i].tokens[
+                        -new_input_length - 1 :
+                    ]
+                else:
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
                 prefill_tokens = PrefillTokens(
                     prefill_token_ids, prefill_logprobs, prefill_texts
                 )
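The CausalLM change above leans on the fast-tokenizer path of transformers: with a Rust-backed tokenizer, the BatchEncoding returned by tokenizer(...) exposes a list of tokenizers.Encoding objects via .encodings, and each Encoding already carries the per-token strings in .tokens, so no batch_decode round-trip is needed; slow tokenizers leave .encodings as None, which is what the else branch falls back on. A small sketch of that API (the model name is only an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any model with a fast tokenizer

batch_encoding = tokenizer(["Deep learning is"], truncation=True, max_length=1024)

if batch_encoding.encodings is not None:
    # Fast tokenizer: one tokenizers.Encoding per request in the batch.
    encoding = batch_encoding.encodings[0]
    print(encoding.ids)     # token ids for the first request
    print(encoding.tokens)  # raw token strings (e.g. BPE pieces)
else:
    # Slow tokenizer: fall back to decoding each id, as the diff does.
    ids = batch_encoding["input_ids"][0]
    print([tokenizer.decode(token_id) for token_id in ids])

Note that Encoding.tokens returns the raw tokenizer pieces (for example "Ġworld" with a GPT-2 style BPE) while batch_decode returns decoded text (" world"), so the two branches can render the same token slightly differently.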


@@ -5,6 +5,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
+from tokenizers import Encoding
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Union, Dict
 
@@ -72,6 +73,9 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
+    # Input encodings
+    encodings: Optional[List[Encoding]] = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -94,9 +98,9 @@ class FlashCausalLMBatch(Batch):
             batch_inputs.append(r.inputs)
             max_truncation = max(max_truncation, r.truncate)
 
-        batch_tokenized_inputs = tokenizer(
+        batch_encoding = tokenizer(
             batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
+        )
 
         position_ids = []
         past_present_indices = []
@@ -130,7 +134,7 @@ class FlashCausalLMBatch(Batch):
 
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
-            zip(pb.requests, batch_tokenized_inputs)
+            zip(pb.requests, batch_encoding["input_ids"])
        ):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
@@ -282,6 +286,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             max_tokens=cumulative_max_length,
+            encodings=batch_encoding.encodings,
         )
 
     @tracer.start_as_current_span("filter")
@@ -822,11 +827,18 @@ class FlashCausalLM(Model):
                     out_start_index : out_end_index - 1
                 ]
                 prefill_token_ids = all_input_ids[:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
+
+                if batch.encodings is not None:
+                    prefill_texts = batch.encodings[i].tokens[
+                        -len(all_input_ids) - 1 :
+                    ]
+                else:
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+
                 prefill_tokens = PrefillTokens(
                     prefill_token_ids, request_prefill_logprobs, prefill_texts
                 )


@@ -61,6 +61,8 @@ class PrefillTokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
+    token_is_special: Optional[List[bool]] = None
+    offsets: Optional[List[bool]] = None
 
     def to_pb(self) -> generate_pb2.PrefillTokens:
         return generate_pb2.PrefillTokens(
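The two new optional fields on PrefillTokens are not filled in anywhere in this diff, but a fast tokenizer's Encoding already exposes per-token metadata of exactly this shape, which is the kind of data they could carry. A hedged sketch, assuming a fast tokenizer; the wiring below is hypothetical and not part of the commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example model only
encoding = tokenizer("Hello world").encodings[0]

# Per-token metadata available on a fast-tokenizer Encoding:
print(encoding.special_tokens_mask)  # 1 where the token is a special token
print(encoding.offsets)              # (start, end) character offsets per token

# Hypothetical mapping onto the new PrefillTokens fields:
token_is_special = [bool(flag) for flag in encoding.special_tokens_mask]
offsets = encoding.offsets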