feat(server): use encoding to get prefill tokens

OlivierDehaene 2023-06-20 18:29:55 +02:00
parent 53aa9194c8
commit 83e442ca9a
4 changed files with 56 additions and 21 deletions

View File

@@ -150,15 +150,27 @@ impl Infer {
match response? {
// Add prefill tokens
InferStreamResponse::Prefill(tokens) => {
let tokens_length = tokens.ids.len();
// Validation
// logprobs, ids and texts must have the same lengths
if tokens.logprobs.len() != tokens_length || tokens.texts.len() != tokens_length {
return Err(InferError::GenerationError(format!("Prefill tokens do not have the correct lengths")))
}
result_prefill = Vec::with_capacity(tokens_length);
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
result_prefill = tokens
.ids
.into_iter()
.zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
for ((id, logprob), text) in tokens.ids.into_iter().zip(
tokens.logprobs.into_iter()
).zip(tokens.texts.into_iter()) {
result_prefill.push(PrefillToken{
id,
text,
logprob,
});
}
}
// Push last token
InferStreamResponse::Token(token) => result_tokens.push(token),
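The router now checks that the prefill ids, logprobs and texts it receives all have the same length before building `PrefillToken` objects, instead of zipping mismatched lists silently. As a point of reference, here is a minimal Python sketch of that same invariant as the server side has to uphold it (the helper name is illustrative, not from the repo):

```python
from typing import List


def check_prefill_lengths(ids: List[int], logprobs: List[float], texts: List[str]) -> None:
    # Illustrative helper: mirrors the length validation the router performs
    # on the prefill tokens streamed back from the Python server.
    tokens_length = len(ids)
    if len(logprobs) != tokens_length or len(texts) != tokens_length:
        raise ValueError(
            f"Prefill tokens do not have the correct lengths: "
            f"{tokens_length} ids, {len(logprobs)} logprobs, {len(texts)} texts"
        )
```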

View File

@@ -1,8 +1,8 @@
import torch
import inspect
from dataclasses import dataclass
from opentelemetry import trace
from tokenizers import Encoding
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
@@ -53,6 +53,9 @@ class CausalLMBatch(Batch):
# Past metadata
keys_head_dim_last: bool = True
# Input encodings
encodings: Optional[List[Encoding]] = None
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
@@ -141,6 +144,7 @@ class CausalLMBatch(Batch):
max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
encodings=tokenized_inputs.encodings,
)
@tracer.start_as_current_span("filter")
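`CausalLMBatch` now carries the `tokenizers.Encoding` objects produced by the tokenizer call, which is why the new `encodings` field is `Optional`: only fast (Rust-backed) tokenizers populate `BatchEncoding.encodings`, while slow Python tokenizers leave it as `None`. A minimal sketch of that behaviour, using an illustrative model name rather than anything from the commit:

```python
from transformers import AutoTokenizer

# Illustrative checkpoint; any model with a fast tokenizer behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

tokenized_inputs = tokenizer(["Deep learning is"], truncation=True, max_length=1024)

# Fast tokenizers expose one tokenizers.Encoding per input sequence;
# slow (pure-Python) tokenizers return None here, hence Optional[List[Encoding]].
encodings = tokenized_inputs.encodings
if encodings is not None:
    print(encodings[0].ids)     # token ids for the prompt
    print(encodings[0].tokens)  # raw token strings, e.g. 'Ġlearning' for byte-level BPE
```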
@@ -625,11 +629,16 @@ class CausalLM(Model):
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
if batch.encodings is not None:
prefill_texts = batch.encodings[i].tokens[
-new_input_length - 1 :
]
else:
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, prefill_logprobs, prefill_texts
)
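Previously every prefill id was re-decoded through `batch_decode`; with the stored encodings, the `CausalLM` generation path can reuse the token strings the tokenizer already produced. A rough sketch of the two paths side by side (illustrative model and prompt, not the commit's code):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative
encoding = tokenizer(["The quick brown fox"]).encodings[0]

prefill_token_ids = encoding.ids[:-1]  # every prompt token except the last one

# Old path: decode each prefill id as its own one-token sequence.
decoded_texts = tokenizer.batch_decode(
    [[token_id] for token_id in prefill_token_ids],
    clean_up_tokenization_spaces=False,
    skip_special_tokens=False,
)

# New path: reuse the token strings stored in the Encoding. Note these are the
# raw tokenizer symbols (e.g. 'Ġquick'), while decoding maps them back to plain
# text (' quick'), so the surface form of the prefill texts can differ.
encoded_texts = encoding.tokens[: len(prefill_token_ids)]

print(decoded_texts)
print(encoded_texts)
```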

View File

@@ -5,6 +5,7 @@ import numpy as np
from dataclasses import dataclass
from opentelemetry import trace
from tokenizers import Encoding
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union, Dict
@@ -72,6 +73,9 @@ class FlashCausalLMBatch(Batch):
# Maximum number of tokens this batch will grow to
max_tokens: int
# Input encodings
encodings: Optional[List[Encoding]] = None
def to_pb(self) -> generate_pb2.CachedBatch:
return generate_pb2.CachedBatch(
id=self.batch_id,
@@ -94,9 +98,9 @@ class FlashCausalLMBatch(Batch):
batch_inputs.append(r.inputs)
max_truncation = max(max_truncation, r.truncate)
batch_tokenized_inputs = tokenizer(
batch_encoding = tokenizer(
batch_inputs, truncation=True, max_length=max_truncation
)["input_ids"]
)
position_ids = []
past_present_indices = []
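In `FlashCausalLMBatch.from_pb`, keeping the whole `BatchEncoding` instead of immediately indexing `["input_ids"]` is what keeps `.encodings` available when the batch is assembled below. A short sketch of the difference, under the same illustrative fast-tokenizer assumption as above:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative
batch_inputs = ["Hello world", "The quick brown fox"]

batch_encoding = tokenizer(batch_inputs, truncation=True, max_length=1024)

input_ids = batch_encoding["input_ids"]  # List[List[int]]: all the old code kept
encodings = batch_encoding.encodings     # per-request Encoding objects (fast tokenizers only)

for ids, enc in zip(input_ids, encodings):
    # Same ids, but the Encoding also carries .tokens, .offsets and .special_tokens_mask.
    assert ids == enc.ids
```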
@@ -130,7 +134,7 @@
# Parse batch
for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
zip(pb.requests, batch_encoding["input_ids"])
):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
@@ -282,6 +286,7 @@
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=cumulative_max_length,
encodings=batch_encoding.encodings,
)
@tracer.start_as_current_span("filter")
@@ -822,11 +827,18 @@ class FlashCausalLM(Model):
out_start_index : out_end_index - 1
]
prefill_token_ids = all_input_ids[:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
if batch.encodings is not None:
prefill_texts = batch.encodings[i].tokens[
-len(all_input_ids) - 1 :
]
else:
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, request_prefill_logprobs, prefill_texts
)

View File

@@ -61,6 +61,8 @@ class PrefillTokens:
token_ids: List[int]
logprobs: List[float]
texts: List[str]
token_is_special: Optional[List[bool]] = None
offsets: Optional[List[bool]] = None
def to_pb(self) -> generate_pb2.PrefillTokens:
return generate_pb2.PrefillTokens(
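The two new `PrefillTokens` fields default to `None`, so the existing positional calls shown above, `PrefillTokens(prefill_token_ids, prefill_logprobs, prefill_texts)`, keep working unchanged while new callers can attach extra per-token metadata. A trimmed-down sketch (stand-in dataclass, dummy values):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PrefillTokens:  # stand-in for the class in the hunk above
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    token_is_special: Optional[List[bool]] = None
    offsets: Optional[List[bool]] = None


# Existing call sites pass only the first three fields...
tokens = PrefillTokens([42], [-1.25], ["Ġis"])
assert tokens.token_is_special is None and tokens.offsets is None

# ...while new callers can also attach the optional per-token metadata.
tokens = PrefillTokens([42], [-1.25], ["Ġis"], token_is_special=[False])
```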