feat(server): add special token bool (#85)
This commit is contained in:
parent
4b1c9720c0
commit
0ac184ce77
|
@ -1,122 +1,142 @@
|
|||
{
|
||||
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
|
||||
"details": {
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 20,
|
||||
"seed": null,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 10264,
|
||||
"logprob": null,
|
||||
"text": "Test"
|
||||
"text": "Test",
|
||||
"logprob": null
|
||||
},
|
||||
{
|
||||
"id": 8821,
|
||||
"logprob": -11.894989,
|
||||
"text": " request"
|
||||
"text": " request",
|
||||
"logprob": -11.894989
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 17,
|
||||
"text": ".",
|
||||
"logprob": -1.8267672,
|
||||
"text": "."
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 1587,
|
||||
"text": "get",
|
||||
"logprob": -2.4674969,
|
||||
"text": "get"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 11,
|
||||
"text": "(",
|
||||
"logprob": -1.906001,
|
||||
"text": "("
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"text": "\"",
|
||||
"logprob": -1.2279545,
|
||||
"text": "\""
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 4899,
|
||||
"text": "action",
|
||||
"logprob": -4.170299,
|
||||
"text": "action"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"text": "\"",
|
||||
"logprob": -0.32478866,
|
||||
"text": "\""
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"text": ")",
|
||||
"logprob": -1.0773665,
|
||||
"text": ")"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"text": ";",
|
||||
"logprob": -0.27640742,
|
||||
"text": ";"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 837,
|
||||
"text": "\n ",
|
||||
"logprob": -1.6970354,
|
||||
"text": "\n "
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 1320,
|
||||
"text": " if",
|
||||
"logprob": -1.4495516,
|
||||
"text": " if"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 375,
|
||||
"text": " (",
|
||||
"logprob": -0.23609057,
|
||||
"text": " ("
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 4899,
|
||||
"text": "action",
|
||||
"logprob": -1.1916996,
|
||||
"text": "action"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 3535,
|
||||
"text": " ==",
|
||||
"logprob": -0.8918753,
|
||||
"text": " =="
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 5109,
|
||||
"text": " null",
|
||||
"logprob": -0.3933342,
|
||||
"text": " null"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"text": ")",
|
||||
"logprob": -0.43212673,
|
||||
"text": ")"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 731,
|
||||
"text": " {",
|
||||
"logprob": -0.17702064,
|
||||
"text": " {"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 1260,
|
||||
"text": "\n ",
|
||||
"logprob": -0.07027565,
|
||||
"text": "\n "
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 10519,
|
||||
"text": " throw",
|
||||
"logprob": -1.3915029,
|
||||
"text": " throw"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 2084,
|
||||
"text": " new",
|
||||
"logprob": -0.04201372,
|
||||
"text": " new"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 150858,
|
||||
"text": " RuntimeException",
|
||||
"logprob": -1.7329919,
|
||||
"text": " RuntimeException"
|
||||
"special": false
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
|
||||
}
|
||||
}
|
|
@ -14,6 +14,7 @@ pub struct Token {
|
|||
id: u32,
|
||||
text: String,
|
||||
logprob: Option<f32>,
|
||||
special: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
@ -136,6 +137,7 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
|
|||
{
|
||||
assert_eq!(token.id, expected_token.id);
|
||||
assert_eq!(token.text, expected_token.text);
|
||||
assert_eq!(token.special, expected_token.special);
|
||||
if let Some(logprob) = token.logprob {
|
||||
let expected_logprob = expected_token.logprob.unwrap();
|
||||
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
|
||||
|
|
|
@ -1,117 +1,137 @@
|
|||
{
|
||||
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
|
||||
"details": {
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 20,
|
||||
"seed": null,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": null,
|
||||
"text": "<pad>"
|
||||
"text": "<pad>",
|
||||
"logprob": null
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 259,
|
||||
"text": "",
|
||||
"logprob": -1.3656927,
|
||||
"text": ""
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 215100,
|
||||
"text": "\"\"\"",
|
||||
"logprob": -2.6551573,
|
||||
"text": "\"\"\""
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 46138,
|
||||
"text": "Test",
|
||||
"logprob": -1.8059857,
|
||||
"text": "Test"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 287,
|
||||
"text": "the",
|
||||
"logprob": -1.2102449,
|
||||
"text": "the"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"text": "",
|
||||
"logprob": -1.6057279,
|
||||
"text": ""
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 49076,
|
||||
"text": "contents",
|
||||
"logprob": -3.6060903,
|
||||
"text": "contents"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 304,
|
||||
"text": "of",
|
||||
"logprob": -0.5270343,
|
||||
"text": "of"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 287,
|
||||
"text": "the",
|
||||
"logprob": -0.62522805,
|
||||
"text": "the"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"text": "",
|
||||
"logprob": -1.4069618,
|
||||
"text": ""
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 49076,
|
||||
"text": "contents",
|
||||
"logprob": -2.621994,
|
||||
"text": "contents"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 304,
|
||||
"text": "of",
|
||||
"logprob": -1.3172221,
|
||||
"text": "of"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 287,
|
||||
"text": "the",
|
||||
"logprob": -0.3501925,
|
||||
"text": "the"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"text": "",
|
||||
"logprob": -0.7219573,
|
||||
"text": ""
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 49076,
|
||||
"text": "contents",
|
||||
"logprob": -1.0494149,
|
||||
"text": "contents"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 260,
|
||||
"text": ".",
|
||||
"logprob": -1.0803378,
|
||||
"text": "."
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"text": "",
|
||||
"logprob": -0.32933083,
|
||||
"text": ""
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 215100,
|
||||
"text": "\"\"\"",
|
||||
"logprob": -0.11268901,
|
||||
"text": "\"\"\""
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 2978,
|
||||
"text": "test",
|
||||
"logprob": -1.5846587,
|
||||
"text": "test"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 290,
|
||||
"text": "_",
|
||||
"logprob": -0.49796978,
|
||||
"text": "_"
|
||||
"special": false
|
||||
},
|
||||
{
|
||||
"id": 4125,
|
||||
"text": "test",
|
||||
"logprob": -2.0026445,
|
||||
"text": "test"
|
||||
"special": false
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
|
||||
}
|
||||
}
|
|
@ -108,8 +108,10 @@ message Generation {
|
|||
float token_logprob = 4;
|
||||
/// Text
|
||||
string token_text = 5;
|
||||
/// Is it a special token
|
||||
bool token_is_special = 6;
|
||||
/// Complete generated text
|
||||
GeneratedText generated_text = 6;
|
||||
GeneratedText generated_text = 7;
|
||||
}
|
||||
|
||||
message PrefillRequest {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/// Batching and inference logic
|
||||
use crate::validation::{Validation, ValidationError};
|
||||
use crate::GenerateRequest;
|
||||
use crate::{Entry, Queue, Token};
|
||||
use crate::{GenerateRequest, PrefillToken};
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{
|
||||
|
@ -138,7 +138,7 @@ impl Infer {
|
|||
.into_iter()
|
||||
.zip(tokens.logprobs.into_iter())
|
||||
.zip(tokens.texts.into_iter())
|
||||
.map(|((id, logprob), text)| Token { id, text, logprob })
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
}
|
||||
// Push last token
|
||||
|
@ -372,6 +372,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
|
|||
id: generation.token_id,
|
||||
text: generation.token_text,
|
||||
logprob: generation.token_logprob,
|
||||
special: generation.token_is_special,
|
||||
};
|
||||
|
||||
if let Some(generated_text) = generation.generated_text {
|
||||
|
@ -420,7 +421,7 @@ pub(crate) enum InferStreamResponse {
|
|||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct InferResponse {
|
||||
pub(crate) prefill: Vec<Token>,
|
||||
pub(crate) prefill: Vec<PrefillToken>,
|
||||
pub(crate) tokens: Vec<Token>,
|
||||
pub(crate) generated_text: GeneratedText,
|
||||
pub(crate) queued: Instant,
|
||||
|
|
|
@ -86,6 +86,16 @@ pub(crate) struct GenerateRequest {
|
|||
pub parameters: GenerateParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct PrefillToken {
|
||||
#[schema(example = 0)]
|
||||
id: u32,
|
||||
#[schema(example = "test")]
|
||||
text: String,
|
||||
#[schema(nullable = true, example = -0.34)]
|
||||
logprob: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct Token {
|
||||
#[schema(example = 0)]
|
||||
|
@ -94,6 +104,8 @@ pub struct Token {
|
|||
text: String,
|
||||
#[schema(nullable = true, example = -0.34)]
|
||||
logprob: f32,
|
||||
#[schema(example = "false")]
|
||||
special: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, ToSchema)]
|
||||
|
@ -116,7 +128,7 @@ pub(crate) struct Details {
|
|||
pub generated_tokens: u32,
|
||||
#[schema(example = 42)]
|
||||
pub seed: Option<u64>,
|
||||
pub prefill: Option<Vec<Token>>,
|
||||
pub prefill: Option<Vec<PrefillToken>>,
|
||||
pub tokens: Option<Vec<Token>>,
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
use crate::infer::{InferError, InferStreamResponse};
|
||||
use crate::{
|
||||
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
|
||||
Infer, StreamDetails, StreamResponse, Token, Validation,
|
||||
Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
|
@ -255,11 +255,11 @@ async fn generate_stream(
|
|||
let time_per_token = inference_time / generated_text.generated_tokens;
|
||||
|
||||
// Tracing metadata
|
||||
span.record("total_time", format!("{:?}", total_time));
|
||||
span.record("validation_time", format!("{:?}", validation_time));
|
||||
span.record("queue_time", format!("{:?}", queue_time));
|
||||
span.record("inference_time", format!("{:?}", inference_time));
|
||||
span.record("time_per_token", format!("{:?}", time_per_token));
|
||||
span.record("total_time", format!("{total_time:?}"));
|
||||
span.record("validation_time", format!("{validation_time:?}"));
|
||||
span.record("queue_time", format!("{queue_time:?}"));
|
||||
span.record("inference_time", format!("{inference_time:?}"));
|
||||
span.record("time_per_token", format!("{time_per_token:?}"));
|
||||
span.record("seed", format!("{:?}", generated_text.seed));
|
||||
tracing::info!(parent: &span, "Output: {}", generated_text.text);
|
||||
|
||||
|
@ -349,6 +349,7 @@ pub async fn run(
|
|||
schemas(
|
||||
GenerateRequest,
|
||||
GenerateParameters,
|
||||
PrefillToken,
|
||||
Token,
|
||||
GenerateResponse,
|
||||
Details,
|
||||
|
|
|
@ -172,7 +172,9 @@ class CausalLMBatch(Batch):
|
|||
# and to remove unused allocated space
|
||||
left_offset = max_sequence_length - batch.max_sequence_length
|
||||
batch_left_offset = (
|
||||
batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
|
||||
batch.attention_mask.shape[1]
|
||||
- batch.max_sequence_length
|
||||
- batch.padding_right_offset
|
||||
)
|
||||
attention_mask[
|
||||
start_index:end_index,
|
||||
|
@ -443,6 +445,7 @@ class CausalLM(Model):
|
|||
next_token_id_squeezed,
|
||||
next_token_logprob,
|
||||
next_token_text,
|
||||
next_token_id_squeezed in self.all_special_ids,
|
||||
generated_text,
|
||||
)
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ B = TypeVar("B", bound=Batch)
|
|||
class Model(ABC):
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
|
||||
self.tokenizer = tokenizer
|
||||
self.all_special_ids = set(tokenizer.all_special_ids)
|
||||
self.device = device
|
||||
|
||||
@property
|
||||
|
|
|
@ -205,7 +205,8 @@ class Seq2SeqLMBatch(Batch):
|
|||
else:
|
||||
batch_left_offset = (
|
||||
batch.decoder_attention_mask.shape[1]
|
||||
- batch.max_decoder_input_length - batch.padding_right_offset
|
||||
- batch.max_decoder_input_length
|
||||
- batch.padding_right_offset
|
||||
)
|
||||
decoder_attention_mask[
|
||||
start_index:end_index,
|
||||
|
@ -494,14 +495,10 @@ class Seq2SeqLM(Model):
|
|||
|
||||
# Prefill
|
||||
if stopping_criteria.current_tokens == 1:
|
||||
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
|
||||
prefill_texts = self.tokenizer.batch_decode(
|
||||
prefill_token_ids,
|
||||
clean_up_tokenization_spaces=False,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
prefill_tokens = PrefillTokens(
|
||||
prefill_token_ids, [float("nan")], prefill_texts
|
||||
[self.tokenizer.bos_token_id],
|
||||
[float("nan")],
|
||||
[self.tokenizer.bos_token],
|
||||
)
|
||||
else:
|
||||
prefill_tokens = None
|
||||
|
@ -512,6 +509,7 @@ class Seq2SeqLM(Model):
|
|||
next_token_id_squeezed,
|
||||
next_token_logprob,
|
||||
next_token_text,
|
||||
next_token_id_squeezed in self.all_special_ids,
|
||||
generated_text,
|
||||
)
|
||||
|
||||
|
|
|
@ -73,6 +73,7 @@ class Generation:
|
|||
token_id: int
|
||||
token_logprob: float
|
||||
token_text: str
|
||||
token_is_special: bool
|
||||
generated_text: Optional[GeneratedText]
|
||||
|
||||
def to_pb(self) -> generate_pb2.Generation:
|
||||
|
@ -84,6 +85,7 @@ class Generation:
|
|||
token_id=self.token_id,
|
||||
token_logprob=self.token_logprob,
|
||||
token_text=self.token_text,
|
||||
token_is_special=self.token_is_special,
|
||||
generated_text=self.generated_text.to_pb()
|
||||
if self.generated_text is not None
|
||||
else None,
|
||||
|
|
Loading…
Reference in New Issue