feat(server): use encoding to get prefill tokens
parent 53aa9194c8
commit 83e442ca9a
@@ -150,15 +150,27 @@ impl Infer {
             match response? {
                 // Add prefill tokens
                 InferStreamResponse::Prefill(tokens) => {
+                    let tokens_length = tokens.ids.len();
+
+                    // Validation
+                    // logprobs, ids and texts must have the same lengths
+                    if tokens.logprobs.len() != tokens_length || tokens.texts.len() != tokens_length {
+                        return Err(InferError::GenerationError(format!("Prefill tokens do not have the correct lengths")))
+                    }
+
+                    result_prefill = Vec::with_capacity(tokens_length);
+
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
-                    result_prefill = tokens
-                        .ids
-                        .into_iter()
-                        .zip(tokens.logprobs.into_iter())
-                        .zip(tokens.texts.into_iter())
-                        .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
-                        .collect();
+                    for ((id, logprob), text) in tokens.ids.into_iter().zip(
+                        tokens.logprobs.into_iter()
+                    ).zip(tokens.texts.into_iter()) {
+                        result_prefill.push(PrefillToken{
+                            id,
+                            text,
+                            logprob,
+                        });
+                    }
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
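Note: the new match arm first checks that the prefill ids, logprobs and texts line up before building one PrefillToken per position. Those three lists are assembled on the Python server, where the texts now come from the tokenizer encoding instead of per-id decoding. A minimal Python sketch of the invariant the check enforces, using "gpt2" purely as an example checkpoint and placeholder logprob values:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint; any fast tokenizer behaves the same
encoding = tokenizer("Make a sandwich").encodings[0]

ids = encoding.ids              # prefill token ids
texts = encoding.tokens         # token strings taken straight from the encoding
logprobs = [0.0] * len(ids)     # placeholder logprobs, for illustration only

# The router-side validation above enforces exactly this length invariant.
assert len(ids) == len(texts) == len(logprobs)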
@@ -1,8 +1,8 @@
 import torch
-import inspect
 
 from dataclasses import dataclass
 from opentelemetry import trace
+from tokenizers import Encoding
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict
@@ -53,6 +53,9 @@ class CausalLMBatch(Batch):
     # Past metadata
     keys_head_dim_last: bool = True
 
+    # Input encodings
+    encodings: Optional[List[Encoding]] = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -141,6 +144,7 @@ class CausalLMBatch(Batch):
             max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
             max_tokens=max_tokens,
+            encodings=tokenized_inputs.encodings,
         )
 
     @tracer.start_as_current_span("filter")
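Note: tokenized_inputs is the BatchEncoding returned by calling the tokenizer. With a fast (Rust-backed) tokenizer its .encodings attribute holds one tokenizers.Encoding per input; with a slow tokenizer it is None, which is why the generation path below keeps a batch_decode fallback. A hedged sketch of that behaviour ("gpt2" is only an example checkpoint):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenized_inputs = tokenizer(["Deep learning is", "Hello"])

# One Encoding per request with a fast tokenizer; None with a slow one.
encodings = tokenized_inputs.encodings
print(encodings[0].ids)     # token ids of the first request
print(encodings[0].tokens)  # matching token strings, no decode call needed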
@@ -625,11 +629,16 @@ class CausalLM(Model):
                     -new_input_length:-1
                 ].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
+                if batch.encodings is not None:
+                    prefill_texts = batch.encodings[i].tokens[
+                        -new_input_length - 1 :
+                    ]
+                else:
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
                 prefill_tokens = PrefillTokens(
                     prefill_token_ids, prefill_logprobs, prefill_texts
                 )
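Note: Encoding.tokens returns the tokenizer's own token strings (for a byte-level BPE these include markers such as "Ġ"), while the old batch_decode path returned decoded text pieces, one per id. A small comparison sketch, assuming a fast tokenizer and using "gpt2" purely as an example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
encoding = tokenizer("Hello world").encodings[0]
prefill_token_ids = encoding.ids

# Old path: one decode per prefill token id.
decoded = tokenizer.batch_decode(
    prefill_token_ids,
    clean_up_tokenization_spaces=False,
    skip_special_tokens=False,
)

# New path: the token strings are already carried by the Encoding.
tokens = encoding.tokens

print(decoded)  # e.g. ['Hello', ' world']  -- decoded text pieces
print(tokens)   # e.g. ['Hello', 'Ġworld']  -- raw BPE token strings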
@@ -5,6 +5,7 @@ import numpy as np
 
 from dataclasses import dataclass
 from opentelemetry import trace
+from tokenizers import Encoding
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
 from typing import Optional, Tuple, List, Type, Union, Dict
@@ -72,6 +73,9 @@ class FlashCausalLMBatch(Batch):
     # Maximum number of tokens this batch will grow to
     max_tokens: int
 
+    # Input encodings
+    encodings: Optional[List[Encoding]] = None
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -94,9 +98,9 @@ class FlashCausalLMBatch(Batch):
             batch_inputs.append(r.inputs)
             max_truncation = max(max_truncation, r.truncate)
 
-        batch_tokenized_inputs = tokenizer(
+        batch_encoding = tokenizer(
             batch_inputs, truncation=True, max_length=max_truncation
-        )["input_ids"]
+        )
 
         position_ids = []
         past_present_indices = []
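Note: keeping the whole BatchEncoding (instead of indexing ["input_ids"] right away) leaves both views available: the input ids the rest of from_pb already uses, and the per-request Encoding objects stored on the batch. A hedged sketch of that access pattern, with "gpt2" and the request strings as stand-in examples:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
batch_inputs = ["First request", "Second request"]

batch_encoding = tokenizer(batch_inputs, truncation=True, max_length=1024)

input_ids = batch_encoding["input_ids"]  # what batch_tokenized_inputs used to hold
encodings = batch_encoding.encodings     # per-request Encodings (None for slow tokenizers)

for ids, encoding in zip(input_ids, encodings):
    assert ids == encoding.ids           # the two views stay in sync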
@@ -130,7 +134,7 @@ class FlashCausalLMBatch(Batch):
 
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
-            zip(pb.requests, batch_tokenized_inputs)
+            zip(pb.requests, batch_encoding["input_ids"])
         ):
             # request id -> idx in list mapping
             requests_idx_mapping[r.id] = i
@@ -282,6 +286,7 @@ class FlashCausalLMBatch(Batch):
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             max_tokens=cumulative_max_length,
+            encodings=batch_encoding.encodings,
         )
 
     @tracer.start_as_current_span("filter")
@@ -822,11 +827,18 @@ class FlashCausalLM(Model):
                     out_start_index : out_end_index - 1
                 ]
                 prefill_token_ids = all_input_ids[:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
+
+                if batch.encodings is not None:
+                    prefill_texts = batch.encodings[i].tokens[
+                        -len(all_input_ids) - 1 :
+                    ]
+                else:
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+
                 prefill_tokens = PrefillTokens(
                     prefill_token_ids, request_prefill_logprobs, prefill_texts
                 )
@@ -61,6 +61,8 @@ class PrefillTokens:
     token_ids: List[int]
     logprobs: List[float]
     texts: List[str]
+    token_is_special: Optional[List[bool]] = None
+    offsets: Optional[List[bool]] = None
 
     def to_pb(self) -> generate_pb2.PrefillTokens:
         return generate_pb2.PrefillTokens(
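Note: the two new fields default to None, so the existing three-argument call sites above keep working unchanged. A stand-alone sketch of the extended dataclass for illustration only (the real class also carries the to_pb() conversion to the gRPC message, and the values below are illustrative):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PrefillTokens:
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]
    token_is_special: Optional[List[bool]] = None
    offsets: Optional[List[bool]] = None


# Positional construction, as in the generate_token code above.
prefill_tokens = PrefillTokens([15496, 995], [0.0, -1.2], ["Hello", "Ġworld"])
assert prefill_tokens.token_is_special is None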