feat(server): add special token bool (#85)

OlivierDehaene 2023-02-24 15:55:57 +01:00 committed by GitHub
parent 4b1c9720c0
commit 0ac184ce77
11 changed files with 134 additions and 72 deletions

View File

@@ -1,122 +1,142 @@
{
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"seed": null,
"prefill": [
{
"id": 10264,
"logprob": null,
"text": "Test"
"text": "Test",
"logprob": null
},
{
"id": 8821,
"logprob": -11.894989,
"text": " request"
"text": " request",
"logprob": -11.894989
}
],
"seed": null,
"tokens": [
{
"id": 17,
"text": ".",
"logprob": -1.8267672,
"text": "."
"special": false
},
{
"id": 1587,
"text": "get",
"logprob": -2.4674969,
"text": "get"
"special": false
},
{
"id": 11,
"text": "(",
"logprob": -1.906001,
"text": "("
"special": false
},
{
"id": 5,
"text": "\"",
"logprob": -1.2279545,
"text": "\""
"special": false
},
{
"id": 4899,
"text": "action",
"logprob": -4.170299,
"text": "action"
"special": false
},
{
"id": 5,
"text": "\"",
"logprob": -0.32478866,
"text": "\""
"special": false
},
{
"id": 12,
"text": ")",
"logprob": -1.0773665,
"text": ")"
"special": false
},
{
"id": 30,
"text": ";",
"logprob": -0.27640742,
"text": ";"
"special": false
},
{
"id": 837,
"text": "\n ",
"logprob": -1.6970354,
"text": "\n "
"special": false
},
{
"id": 1320,
"text": " if",
"logprob": -1.4495516,
"text": " if"
"special": false
},
{
"id": 375,
"text": " (",
"logprob": -0.23609057,
"text": " ("
"special": false
},
{
"id": 4899,
"text": "action",
"logprob": -1.1916996,
"text": "action"
"special": false
},
{
"id": 3535,
"text": " ==",
"logprob": -0.8918753,
"text": " =="
"special": false
},
{
"id": 5109,
"text": " null",
"logprob": -0.3933342,
"text": " null"
"special": false
},
{
"id": 12,
"text": ")",
"logprob": -0.43212673,
"text": ")"
"special": false
},
{
"id": 731,
"text": " {",
"logprob": -0.17702064,
"text": " {"
"special": false
},
{
"id": 1260,
"text": "\n ",
"logprob": -0.07027565,
"text": "\n "
"special": false
},
{
"id": 10519,
"text": " throw",
"logprob": -1.3915029,
"text": " throw"
"special": false
},
{
"id": 2084,
"text": " new",
"logprob": -0.04201372,
"text": " new"
"special": false
},
{
"id": 150858,
"text": " RuntimeException",
"logprob": -1.7329919,
"text": " RuntimeException"
"special": false
}
]
},
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
}
}
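
The fixture above shows the new per-token special boolean in the response details. One way a client could use it is to drop control tokens (for example padding or end-of-sequence markers) when reassembling display text, while still keeping every token available in details.tokens. Below is a minimal client-side sketch in Python; the endpoint URL, port and payload shape are assumptions for illustration, not taken from this commit.

import requests  # any HTTP client would do

# Assumed request shape: ask for details so per-token metadata is returned.
response = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "Test request",
        "parameters": {"max_new_tokens": 20, "details": True},
    },
).json()

# Keep only human-readable output: skip tokens flagged as special.
visible = "".join(t["text"] for t in response["details"]["tokens"] if not t["special"])
print(visible)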

View File

@@ -14,6 +14,7 @@ pub struct Token {
id: u32,
text: String,
logprob: Option<f32>,
special: bool,
}
#[derive(Deserialize)]
@@ -136,6 +137,7 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
{
assert_eq!(token.id, expected_token.id);
assert_eq!(token.text, expected_token.text);
assert_eq!(token.special, expected_token.special);
if let Some(logprob) = token.logprob {
let expected_logprob = expected_token.logprob.unwrap();
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);

View File

@@ -1,117 +1,137 @@
{
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"seed": null,
"prefill": [
{
"id": 0,
"logprob": null,
"text": "<pad>"
"text": "<pad>",
"logprob": null
}
],
"seed": null,
"tokens": [
{
"id": 259,
"text": "",
"logprob": -1.3656927,
"text": ""
"special": false
},
{
"id": 215100,
"text": "\"\"\"",
"logprob": -2.6551573,
"text": "\"\"\""
"special": false
},
{
"id": 46138,
"text": "Test",
"logprob": -1.8059857,
"text": "Test"
"special": false
},
{
"id": 287,
"text": "the",
"logprob": -1.2102449,
"text": "the"
"special": false
},
{
"id": 259,
"text": "",
"logprob": -1.6057279,
"text": ""
"special": false
},
{
"id": 49076,
"text": "contents",
"logprob": -3.6060903,
"text": "contents"
"special": false
},
{
"id": 304,
"text": "of",
"logprob": -0.5270343,
"text": "of"
"special": false
},
{
"id": 287,
"text": "the",
"logprob": -0.62522805,
"text": "the"
"special": false
},
{
"id": 259,
"text": "",
"logprob": -1.4069618,
"text": ""
"special": false
},
{
"id": 49076,
"text": "contents",
"logprob": -2.621994,
"text": "contents"
"special": false
},
{
"id": 304,
"text": "of",
"logprob": -1.3172221,
"text": "of"
"special": false
},
{
"id": 287,
"text": "the",
"logprob": -0.3501925,
"text": "the"
"special": false
},
{
"id": 259,
"text": "",
"logprob": -0.7219573,
"text": ""
"special": false
},
{
"id": 49076,
"text": "contents",
"logprob": -1.0494149,
"text": "contents"
"special": false
},
{
"id": 260,
"text": ".",
"logprob": -1.0803378,
"text": "."
"special": false
},
{
"id": 259,
"text": "",
"logprob": -0.32933083,
"text": ""
"special": false
},
{
"id": 215100,
"text": "\"\"\"",
"logprob": -0.11268901,
"text": "\"\"\""
"special": false
},
{
"id": 2978,
"text": "test",
"logprob": -1.5846587,
"text": "test"
"special": false
},
{
"id": 290,
"text": "_",
"logprob": -0.49796978,
"text": "_"
"special": false
},
{
"id": 4125,
"text": "test",
"logprob": -2.0026445,
"text": "test"
"special": false
}
]
},
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
}
}

View File

@@ -108,8 +108,10 @@ message Generation {
float token_logprob = 4;
/// Text
string token_text = 5;
/// Is it a special token
bool token_is_special = 6;
/// Complete generated text
GeneratedText generated_text = 6;
GeneratedText generated_text = 7;
}
message PrefillRequest {

View File

@@ -1,7 +1,7 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::GenerateRequest;
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_client::{
@@ -138,7 +138,7 @@ impl Infer {
.into_iter()
.zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| Token { id, text, logprob })
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
}
// Push last token
@@ -372,6 +372,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
id: generation.token_id,
text: generation.token_text,
logprob: generation.token_logprob,
special: generation.token_is_special,
};
if let Some(generated_text) = generation.generated_text {
@@ -420,7 +421,7 @@ pub(crate) enum InferStreamResponse {
#[derive(Debug)]
pub(crate) struct InferResponse {
pub(crate) prefill: Vec<Token>,
pub(crate) prefill: Vec<PrefillToken>,
pub(crate) tokens: Vec<Token>,
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
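
On the router side, the prefill ids, logprobs and texts arrive as three parallel lists and are zipped into PrefillToken values, which carry no special flag. For readers less used to the Rust iterator chain in the hunk above, here is a rough Python analogue of that nested zip, using the prefill values from the first fixture in this commit:

ids = [10264, 8821]
logprobs = [None, -11.894989]
texts = ["Test", " request"]

# Equivalent of .zip(..).zip(..).map(..): pair ids with logprobs, then with texts.
prefill = [
    {"id": token_id, "text": text, "logprob": logprob}
    for (token_id, logprob), text in zip(zip(ids, logprobs), texts)
]
print(prefill)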

View File

@@ -86,6 +86,16 @@ pub(crate) struct GenerateRequest {
pub parameters: GenerateParameters,
}
#[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken {
#[schema(example = 0)]
id: u32,
#[schema(example = "test")]
text: String,
#[schema(nullable = true, example = -0.34)]
logprob: f32,
}
#[derive(Debug, Serialize, ToSchema)]
pub struct Token {
#[schema(example = 0)]
@@ -94,6 +104,8 @@ pub struct Token {
text: String,
#[schema(nullable = true, example = -0.34)]
logprob: f32,
#[schema(example = "false")]
special: bool,
}
#[derive(Serialize, ToSchema)]
@@ -116,7 +128,7 @@ pub(crate) struct Details {
pub generated_tokens: u32,
#[schema(example = 42)]
pub seed: Option<u64>,
pub prefill: Option<Vec<Token>>,
pub prefill: Option<Vec<PrefillToken>>,
pub tokens: Option<Vec<Token>>,
}

View File

@@ -2,7 +2,7 @@
use crate::infer::{InferError, InferStreamResponse};
use crate::{
Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
Infer, StreamDetails, StreamResponse, Token, Validation,
Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@@ -255,11 +255,11 @@ async fn generate_stream(
let time_per_token = inference_time / generated_text.generated_tokens;
// Tracing metadata
span.record("total_time", format!("{:?}", total_time));
span.record("validation_time", format!("{:?}", validation_time));
span.record("queue_time", format!("{:?}", queue_time));
span.record("inference_time", format!("{:?}", inference_time));
span.record("time_per_token", format!("{:?}", time_per_token));
span.record("total_time", format!("{total_time:?}"));
span.record("validation_time", format!("{validation_time:?}"));
span.record("queue_time", format!("{queue_time:?}"));
span.record("inference_time", format!("{inference_time:?}"));
span.record("time_per_token", format!("{time_per_token:?}"));
span.record("seed", format!("{:?}", generated_text.seed));
tracing::info!(parent: &span, "Output: {}", generated_text.text);
@@ -349,6 +349,7 @@ pub async fn run(
schemas(
GenerateRequest,
GenerateParameters,
PrefillToken,
Token,
GenerateResponse,
Details,

View File

@@ -172,7 +172,9 @@ class CausalLMBatch(Batch):
# and to remove unused allocated space
left_offset = max_sequence_length - batch.max_sequence_length
batch_left_offset = (
batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
batch.attention_mask.shape[1]
- batch.max_sequence_length
- batch.padding_right_offset
)
attention_mask[
start_index:end_index,
@@ -443,6 +445,7 @@ class CausalLM(Model):
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed in self.all_special_ids,
generated_text,
)
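
The extra argument above is where the new flag is computed: the sampled token id is simply tested for membership in the tokenizer's special ids. A standalone sketch of that check, assuming the transformers library is installed; the model name is only an example and is not part of this commit.

from transformers import AutoTokenizer

# Any Hugging Face tokenizer exposes all_special_ids.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
all_special_ids = set(tokenizer.all_special_ids)  # cached once, as Model.__init__ now does

def is_special(token_id: int) -> bool:
    # Mirrors the membership test passed into Generation above.
    return token_id in all_special_ids

print(is_special(tokenizer.eos_token_id))                                 # True: end-of-sequence
print(is_special(tokenizer.encode("Test", add_special_tokens=False)[0]))  # False: ordinary text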

View File

@@ -12,6 +12,7 @@ B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
self.tokenizer = tokenizer
self.all_special_ids = set(tokenizer.all_special_ids)
self.device = device
@property
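
Caching the ids on the base Model means CausalLM and Seq2SeqLM share one copy, and converting to a set makes the per-token membership test a constant-time lookup. A tiny self-contained illustration of the pattern with a toy tokenizer stand-in, so nothing needs to be downloaded:

from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyTokenizer:
    # Stand-in exposing only the attribute used here; a real tokenizer
    # provides all_special_ids the same way.
    all_special_ids: List[int] = field(default_factory=lambda: [0, 1, 2])

tokenizer = ToyTokenizer()
all_special_ids = set(tokenizer.all_special_ids)  # what Model.__init__ now caches

for token_id in (2, 17, 1587):
    print(token_id, token_id in all_special_ids)  # True only for the toy special id 2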

View File

@@ -205,7 +205,8 @@ class Seq2SeqLMBatch(Batch):
else:
batch_left_offset = (
batch.decoder_attention_mask.shape[1]
- batch.max_decoder_input_length - batch.padding_right_offset
- batch.max_decoder_input_length
- batch.padding_right_offset
)
decoder_attention_mask[
start_index:end_index,
@@ -494,14 +495,10 @@ class Seq2SeqLM(Model):
# Prefill
if stopping_criteria.current_tokens == 1:
prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)
prefill_tokens = PrefillTokens(
prefill_token_ids, [float("nan")], prefill_texts
[self.tokenizer.bos_token_id],
[float("nan")],
[self.tokenizer.bos_token],
)
else:
prefill_tokens = None
@@ -512,6 +509,7 @@ class Seq2SeqLM(Model):
next_token_id_squeezed,
next_token_logprob,
next_token_text,
next_token_id_squeezed in self.all_special_ids,
generated_text,
)

View File

@@ -73,6 +73,7 @@ class Generation:
token_id: int
token_logprob: float
token_text: str
token_is_special: bool
generated_text: Optional[GeneratedText]
def to_pb(self) -> generate_pb2.Generation:
@@ -84,6 +85,7 @@ class Generation:
token_id=self.token_id,
token_logprob=self.token_logprob,
token_text=self.token_text,
token_is_special=self.token_is_special,
generated_text=self.generated_text.to_pb()
if self.generated_text is not None
else None,
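
Putting the server-side pieces together: every decoding step now builds a Generation carrying token_is_special, and the router copies it into the HTTP Token as special. The sketch below is a pared-down stand-in for the dataclass above, omitting request_id, prefill_tokens and the protobuf conversion, just to show how the flag is threaded through.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Generation:
    # Simplified stand-in for the Generation dataclass in the diff above.
    token_id: int
    token_logprob: float
    token_text: str
    token_is_special: bool
    generated_text: Optional[str] = None

all_special_ids = {0, 1, 2}  # illustrative special ids for a toy tokenizer

def make_generation(token_id: int, logprob: float, text: str) -> Generation:
    # Same membership test CausalLM and Seq2SeqLM now pass into Generation.
    return Generation(token_id, logprob, text, token_id in all_special_ids)

print(make_generation(1587, -2.4674969, "get"))  # token_is_special=False
print(make_generation(2, -0.01, "</s>"))         # token_is_special=True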