feat(server): use encoding to get prefill tokens

commit 83e442ca9a
parent 53aa9194c8

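The commit switches the Python server to the token strings that the fast tokenizer already computes, instead of re-decoding every prefill id. For orientation only (not part of the diff; the model name and prompt below are assumptions), a fast tokenizer call returns a `BatchEncoding` whose `.encodings` entries expose ids and token strings directly:

    from transformers import AutoTokenizer

    # Assumption: "gpt2" stands in for whatever model the server actually loads.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    batch_encoding = tokenizer(["Deep learning is"], truncation=True, max_length=1024)

    # .encodings is only populated for fast (Rust-backed) tokenizers; it is None for slow ones.
    encoding = batch_encoding.encodings[0]
    print(encoding.ids)     # prefill token ids
    print(encoding.tokens)  # per-token strings, no extra decode pass needed

These `Encoding` objects are what the `encodings` fields added in the hunks below carry around.
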
@@ -150,15 +150,27 @@ impl Infer {
             match response? {
                 // Add prefill tokens
                 InferStreamResponse::Prefill(tokens) => {
+                    let tokens_length = tokens.ids.len();
+
+                    // Validation
+                    // logprobs, ids and texts must have the same lengths
+                    if tokens.logprobs.len() != tokens_length || tokens.texts.len() != tokens_length {
+                        return Err(InferError::GenerationError(format!("Prefill tokens do not have the correct lengths")))
+                    }
+
+                    result_prefill = Vec::with_capacity(tokens_length);
+
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
-                    result_prefill = tokens
-                        .ids
-                        .into_iter()
-                        .zip(tokens.logprobs.into_iter())
-                        .zip(tokens.texts.into_iter())
-                        .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
-                        .collect();
+                    for ((id, logprob), text) in tokens.ids.into_iter().zip(
+                        tokens.logprobs.into_iter()
+                    ).zip(tokens.texts.into_iter()) {
+                        result_prefill.push(PrefillToken{
+                            id,
+                            text,
+                            logprob,
+                        });
+                    }
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),

@@ -1,8 +1,8 @@
 import torch
 import inspect
 
 from dataclasses import dataclass
 from opentelemetry import trace
+from tokenizers import Encoding
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict

@@ -53,6 +53,9 @@ class CausalLMBatch(Batch):
     # Past metadata
     keys_head_dim_last: bool = True
 
+    # Input encodings
+    encodings: Optional[List[Encoding]] = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,

@@ -141,6 +144,7 @@ class CausalLMBatch(Batch):
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
+            encodings=tokenized_inputs.encodings,
         )
 
     @tracer.start_as_current_span("filter")

@@ -625,11 +629,16 @@ class CausalLM(Model):
                     -new_input_length:-1
                 ].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
+                if batch.encodings is not None:
+                    prefill_texts = batch.encodings[i].tokens[
+                        -new_input_length - 1 :
+                    ]
+                else:
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
                 prefill_tokens = PrefillTokens(
                     prefill_token_ids, prefill_logprobs, prefill_texts
                 )

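The fast path above slices `Encoding.tokens` for the prefill text, while `batch_decode` stays as the fallback for slow tokenizers. A rough standalone comparison of the two sources (model name and prompt are assumptions; the exact slicing from the diff is not reproduced here):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed fast tokenizer

    encoding = tokenizer(["Deep learning is"]).encodings[0]
    prefill_token_ids = encoding.ids

    # Fast path: token strings straight from the Encoding object.
    texts_from_encoding = encoding.tokens

    # Fallback path: decode every prefill id on its own, one sequence per token.
    texts_from_decode = tokenizer.batch_decode(
        [[token_id] for token_id in prefill_token_ids],
        clean_up_tokenization_spaces=False,
        skip_special_tokens=False,
    )

The two are not byte-for-byte identical: for byte-level BPE models such as GPT-2 the raw pieces look like 'Ġlearning', while decoding yields ' learning'.
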
@@ -5,6 +5,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
+from tokenizers import Encoding
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Union, Dict
 

@@ -72,6 +73,9 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
+    # Input encodings
+    encodings: Optional[List[Encoding]] = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,

@@ -94,9 +98,9 @@ class FlashCausalLMBatch(Batch):
             batch_inputs.append(r.inputs)
             max_truncation = max(max_truncation, r.truncate)
 
-        batch_tokenized_inputs = tokenizer(
+        batch_encoding = tokenizer(
             batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
+        )
 
         position_ids = []
         past_present_indices = []

@@ -130,7 +134,7 @@ class FlashCausalLMBatch(Batch):
 
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
-            zip(pb.requests, batch_tokenized_inputs)
+            zip(pb.requests, batch_encoding["input_ids"])
         ):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i

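Keeping the whole `BatchEncoding` instead of immediately indexing `["input_ids"]` changes nothing for the existing loop: the object still supports dict-style access while also carrying the per-request `Encoding` objects that get stored on the batch. A minimal sketch, with the model name assumed:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: any fast tokenizer

    batch_encoding = tokenizer(
        ["first request", "second request"], truncation=True, max_length=1024
    )

    # Dict-style access works as before...
    batch_input_ids = batch_encoding["input_ids"]

    # ...and the per-request Encoding objects ride along for later use.
    for tokenized_input, encoding in zip(batch_input_ids, batch_encoding.encodings):
        assert tokenized_input == encoding.ids
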
@@ -282,6 +286,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             max_tokens=cumulative_max_length,
+            encodings=batch_encoding.encodings,
         )
 
     @tracer.start_as_current_span("filter")

@@ -822,11 +827,18 @@ class FlashCausalLM(Model):
                     out_start_index : out_end_index - 1
                 ]
                 prefill_token_ids = all_input_ids[:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
+
+                if batch.encodings is not None:
+                    prefill_texts = batch.encodings[i].tokens[
+                        -len(all_input_ids) - 1 :
+                    ]
+                else:
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+
                 prefill_tokens = PrefillTokens(
                     prefill_token_ids, request_prefill_logprobs, prefill_texts
                 )

@@ -61,6 +61,8 @@ class PrefillTokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
+    token_is_special: Optional[List[bool]] = None
+    offsets: Optional[List[bool]] = None
 
     def to_pb(self) -> generate_pb2.PrefillTokens:
        return generate_pb2.PrefillTokens(

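The two new optional fields on `PrefillTokens` line up with information the fast tokenizer already exposes per token; the mapping below is an assumption for illustration, not something this diff wires up:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed model
    encoding = tokenizer("Deep learning is").encodings[0]

    print(encoding.special_tokens_mask)  # 1 where a token is special, 0 otherwise
    print(encoding.offsets)              # (start, end) character offsets for each token
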