feat(server): add special token bool (#85)

OlivierDehaene 2023-02-24 15:55:57 +01:00 committed by GitHub
parent 4b1c9720c0
commit 0ac184ce77
11 changed files with 134 additions and 72 deletions
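
In short: the model servers now compute, for every generated token, whether its id is one of the tokenizer's special-token ids; the flag travels through the gRPC Generation message as token_is_special and surfaces in the HTTP API as a "special" boolean on each Token. Prefill tokens move to their own PrefillToken schema, which has no such flag. Brief sketches of the new semantics follow the relevant diffs below.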

View File

@@ -1,122 +1,142 @@
 {
+  "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
   "details": {
     "finish_reason": "length",
     "generated_tokens": 20,
+    "seed": null,
     "prefill": [
       {
         "id": 10264,
-        "logprob": null,
-        "text": "Test"
+        "text": "Test",
+        "logprob": null
       },
       {
         "id": 8821,
-        "logprob": -11.894989,
-        "text": " request"
+        "text": " request",
+        "logprob": -11.894989
       }
     ],
-    "seed": null,
     "tokens": [
       {
         "id": 17,
+        "text": ".",
         "logprob": -1.8267672,
-        "text": "."
+        "special": false
       },
       {
         "id": 1587,
+        "text": "get",
         "logprob": -2.4674969,
-        "text": "get"
+        "special": false
       },
       {
         "id": 11,
+        "text": "(",
         "logprob": -1.906001,
-        "text": "("
+        "special": false
       },
       {
         "id": 5,
+        "text": "\"",
         "logprob": -1.2279545,
-        "text": "\""
+        "special": false
       },
       {
         "id": 4899,
+        "text": "action",
         "logprob": -4.170299,
-        "text": "action"
+        "special": false
       },
       {
         "id": 5,
+        "text": "\"",
         "logprob": -0.32478866,
-        "text": "\""
+        "special": false
       },
       {
         "id": 12,
+        "text": ")",
         "logprob": -1.0773665,
-        "text": ")"
+        "special": false
       },
       {
         "id": 30,
+        "text": ";",
         "logprob": -0.27640742,
-        "text": ";"
+        "special": false
       },
       {
         "id": 837,
+        "text": "\n ",
         "logprob": -1.6970354,
-        "text": "\n "
+        "special": false
       },
       {
         "id": 1320,
+        "text": " if",
         "logprob": -1.4495516,
-        "text": " if"
+        "special": false
       },
       {
         "id": 375,
+        "text": " (",
         "logprob": -0.23609057,
-        "text": " ("
+        "special": false
       },
       {
         "id": 4899,
+        "text": "action",
         "logprob": -1.1916996,
-        "text": "action"
+        "special": false
       },
       {
         "id": 3535,
+        "text": " ==",
         "logprob": -0.8918753,
-        "text": " =="
+        "special": false
       },
       {
         "id": 5109,
+        "text": " null",
         "logprob": -0.3933342,
-        "text": " null"
+        "special": false
      },
       {
         "id": 12,
+        "text": ")",
         "logprob": -0.43212673,
-        "text": ")"
+        "special": false
       },
       {
         "id": 731,
+        "text": " {",
         "logprob": -0.17702064,
-        "text": " {"
+        "special": false
       },
       {
         "id": 1260,
+        "text": "\n ",
         "logprob": -0.07027565,
-        "text": "\n "
+        "special": false
       },
       {
         "id": 10519,
+        "text": " throw",
         "logprob": -1.3915029,
-        "text": " throw"
+        "special": false
       },
       {
         "id": 2084,
+        "text": " new",
         "logprob": -0.04201372,
-        "text": " new"
+        "special": false
       },
       {
         "id": 150858,
+        "text": " RuntimeException",
         "logprob": -1.7329919,
-        "text": " RuntimeException"
+        "special": false
       }
     ]
-  },
-  "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+  }
 }

View File

@@ -14,6 +14,7 @@ pub struct Token {
     id: u32,
     text: String,
     logprob: Option<f32>,
+    special: bool,
 }

 #[derive(Deserialize)]
@@ -136,6 +137,7 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
     {
         assert_eq!(token.id, expected_token.id);
         assert_eq!(token.text, expected_token.text);
+        assert_eq!(token.special, expected_token.special);
         if let Some(logprob) = token.logprob {
             let expected_logprob = expected_token.logprob.unwrap();
             assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
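
Note that the comparison harness checks logprobs with an absolute tolerance rather than exact equality. In Python terms the assertion is roughly (a sketch of the semantics, not the test's actual code):

    def logprobs_close(actual: float, expected: float, tol: float = 0.001) -> bool:
        # Same check as assert_float_eq!(logprob, expected_logprob, abs <= 0.001)
        return abs(actual - expected) <= tol

    assert logprobs_close(-1.8267672, -1.8267)  # drift within 0.001 is accepted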

View File

@@ -1,117 +1,137 @@
 {
+  "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
   "details": {
     "finish_reason": "length",
     "generated_tokens": 20,
+    "seed": null,
     "prefill": [
       {
         "id": 0,
-        "logprob": null,
-        "text": "<pad>"
+        "text": "<pad>",
+        "logprob": null
       }
     ],
-    "seed": null,
     "tokens": [
       {
         "id": 259,
+        "text": "",
         "logprob": -1.3656927,
-        "text": ""
+        "special": false
       },
       {
         "id": 215100,
+        "text": "\"\"\"",
         "logprob": -2.6551573,
-        "text": "\"\"\""
+        "special": false
       },
       {
         "id": 46138,
+        "text": "Test",
         "logprob": -1.8059857,
-        "text": "Test"
+        "special": false
       },
       {
         "id": 287,
+        "text": "the",
         "logprob": -1.2102449,
-        "text": "the"
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -1.6057279,
-        "text": ""
+        "special": false
       },
       {
         "id": 49076,
+        "text": "contents",
         "logprob": -3.6060903,
-        "text": "contents"
+        "special": false
       },
       {
         "id": 304,
+        "text": "of",
         "logprob": -0.5270343,
-        "text": "of"
+        "special": false
       },
       {
         "id": 287,
+        "text": "the",
         "logprob": -0.62522805,
-        "text": "the"
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -1.4069618,
-        "text": ""
+        "special": false
       },
       {
         "id": 49076,
+        "text": "contents",
         "logprob": -2.621994,
-        "text": "contents"
+        "special": false
       },
       {
         "id": 304,
+        "text": "of",
         "logprob": -1.3172221,
-        "text": "of"
+        "special": false
       },
       {
         "id": 287,
+        "text": "the",
         "logprob": -0.3501925,
-        "text": "the"
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -0.7219573,
-        "text": ""
+        "special": false
       },
       {
         "id": 49076,
+        "text": "contents",
         "logprob": -1.0494149,
-        "text": "contents"
+        "special": false
       },
       {
         "id": 260,
+        "text": ".",
         "logprob": -1.0803378,
-        "text": "."
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -0.32933083,
-        "text": ""
+        "special": false
       },
       {
         "id": 215100,
+        "text": "\"\"\"",
         "logprob": -0.11268901,
-        "text": "\"\"\""
+        "special": false
       },
       {
         "id": 2978,
+        "text": "test",
         "logprob": -1.5846587,
-        "text": "test"
+        "special": false
       },
       {
         "id": 290,
+        "text": "_",
         "logprob": -0.49796978,
-        "text": "_"
+        "special": false
       },
       {
         "id": 4125,
+        "text": "test",
         "logprob": -2.0026445,
-        "text": "test"
+        "special": false
       }
     ]
-  },
-  "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
+  }
 }

View File

@@ -108,8 +108,10 @@ message Generation {
     float token_logprob = 4;
     /// Text
     string token_text = 5;
+    /// Is it a special token
+    bool token_is_special = 6;
     /// Complete generated text
-    GeneratedText generated_text = 6;
+    GeneratedText generated_text = 7;
 }

 message PrefillRequest {
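
Note the renumbering above: token_is_special claims tag 6, and generated_text moves from tag 6 to tag 7. Reassigning a protobuf field's tag changes the wire format, so the router and the model shards have to be deployed together after this change; that is acceptable for this internal router-to-shard protocol.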

View File

@@ -1,7 +1,7 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
-use crate::GenerateRequest;
 use crate::{Entry, Queue, Token};
+use crate::{GenerateRequest, PrefillToken};
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_client::{
@@ -138,7 +138,7 @@ impl Infer {
                 .into_iter()
                 .zip(tokens.logprobs.into_iter())
                 .zip(tokens.texts.into_iter())
-                .map(|((id, logprob), text)| Token { id, text, logprob })
+                .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
                 .collect();
         }

         // Push last token
@@ -372,6 +372,7 @@ fn send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entr
             id: generation.token_id,
             text: generation.token_text,
             logprob: generation.token_logprob,
+            special: generation.token_is_special,
         };

         if let Some(generated_text) = generation.generated_text {
@@ -420,7 +421,7 @@ pub(crate) enum InferStreamResponse {
 #[derive(Debug)]
 pub(crate) struct InferResponse {
-    pub(crate) prefill: Vec<Token>,
+    pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
     pub(crate) queued: Instant,

View File

@@ -86,6 +86,16 @@ pub(crate) struct GenerateRequest {
     pub parameters: GenerateParameters,
 }

+#[derive(Debug, Serialize, ToSchema)]
+pub struct PrefillToken {
+    #[schema(example = 0)]
+    id: u32,
+    #[schema(example = "test")]
+    text: String,
+    #[schema(nullable = true, example = -0.34)]
+    logprob: f32,
+}
+
 #[derive(Debug, Serialize, ToSchema)]
 pub struct Token {
     #[schema(example = 0)]
@@ -94,6 +104,8 @@ pub struct Token {
     text: String,
     #[schema(nullable = true, example = -0.34)]
     logprob: f32,
+    #[schema(example = "false")]
+    special: bool,
 }

 #[derive(Serialize, ToSchema)]
@@ -116,7 +128,7 @@ pub(crate) struct Details {
     pub generated_tokens: u32,
     #[schema(example = 42)]
     pub seed: Option<u64>,
-    pub prefill: Option<Vec<Token>>,
+    pub prefill: Option<Vec<PrefillToken>>,
     pub tokens: Option<Vec<Token>>,
 }
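
The practical difference between the two router schemas is exactly what the regenerated JSON fixtures above show. One entry of each, written as Python literals (values lifted from the first fixture; None/False stand for JSON null/false):

    # details.prefill entry (PrefillToken): no "special" field; logprob may be
    # null because the first prompt token has nothing to condition on
    prefill_token = {"id": 10264, "text": "Test", "logprob": None}

    # details.tokens entry (Token): carries the new flag
    token = {"id": 17, "text": ".", "logprob": -1.8267672, "special": False}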

View File

@@ -2,7 +2,7 @@
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
     Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
-    Infer, StreamDetails, StreamResponse, Token, Validation,
+    Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -255,11 +255,11 @@ async fn generate_stream(
             let time_per_token = inference_time / generated_text.generated_tokens;

             // Tracing metadata
-            span.record("total_time", format!("{:?}", total_time));
-            span.record("validation_time", format!("{:?}", validation_time));
-            span.record("queue_time", format!("{:?}", queue_time));
-            span.record("inference_time", format!("{:?}", inference_time));
-            span.record("time_per_token", format!("{:?}", time_per_token));
+            span.record("total_time", format!("{total_time:?}"));
+            span.record("validation_time", format!("{validation_time:?}"));
+            span.record("queue_time", format!("{queue_time:?}"));
+            span.record("inference_time", format!("{inference_time:?}"));
+            span.record("time_per_token", format!("{time_per_token:?}"));
             span.record("seed", format!("{:?}", generated_text.seed));
             tracing::info!(parent: &span, "Output: {}", generated_text.text);
@@ -349,6 +349,7 @@ pub async fn run(
         schemas(
             GenerateRequest,
             GenerateParameters,
+            PrefillToken,
             Token,
             GenerateResponse,
             Details,

View File

@@ -172,7 +172,9 @@ class CausalLMBatch(Batch):
             # and to remove unused allocated space
             left_offset = max_sequence_length - batch.max_sequence_length
             batch_left_offset = (
-                batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
+                batch.attention_mask.shape[1]
+                - batch.max_sequence_length
+                - batch.padding_right_offset
             )
             attention_mask[
                 start_index:end_index,
@@ -443,6 +445,7 @@ class CausalLM(Model):
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
+                next_token_id_squeezed in self.all_special_ids,
                generated_text,
            )

View File

@@ -12,6 +12,7 @@ B = TypeVar("B", bound=Batch)
 class Model(ABC):
     def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device):
         self.tokenizer = tokenizer
+        self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device

     @property
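
Materializing all_special_ids as a set once per model keeps the new per-token check O(1) inside the decode loop, instead of scanning the tokenizer's list on every token. A self-contained sketch of the same computation (assumes the transformers package; the checkpoint name is only illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-base")
    all_special_ids = set(tokenizer.all_special_ids)

    # The membership test the model servers now pass into Generation:
    print(0 in all_special_ids)    # True: <pad>, compare the mt0 fixture's prefill
    print(287 in all_special_ids)  # False: "the" is an ordinary token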

View File

@@ -205,7 +205,8 @@ class Seq2SeqLMBatch(Batch):
             else:
                 batch_left_offset = (
                     batch.decoder_attention_mask.shape[1]
-                    - batch.max_decoder_input_length - batch.padding_right_offset
+                    - batch.max_decoder_input_length
+                    - batch.padding_right_offset
                 )
             decoder_attention_mask[
                 start_index:end_index,
@@ -494,14 +495,10 @@ class Seq2SeqLM(Model):
             # Prefill
             if stopping_criteria.current_tokens == 1:
-                prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
                 prefill_tokens = PrefillTokens(
-                    prefill_token_ids, [float("nan")], prefill_texts
+                    [self.tokenizer.bos_token_id],
+                    [float("nan")],
+                    [self.tokenizer.bos_token],
                 )
             else:
                 prefill_tokens = None
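
The prefill simplification reflects how these encoder-decoder models are prompted: the decoder starts from a single start token, so rather than re-decoding decoder_input_ids the prefill is reported directly as the start token's id and text with a NaN logprob. Illustratively (values match the mt0 fixture, where the start token is <pad>):

    # The three lists PrefillTokens now receives for a seq2seq model:
    ids, logprobs, texts = [0], [float("nan")], ["<pad>"]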
@@ -512,6 +509,7 @@ class Seq2SeqLM(Model):
                next_token_id_squeezed,
                next_token_logprob,
                next_token_text,
+                next_token_id_squeezed in self.all_special_ids,
                generated_text,
            )

View File

@@ -73,6 +73,7 @@ class Generation:
     token_id: int
     token_logprob: float
     token_text: str
+    token_is_special: bool
     generated_text: Optional[GeneratedText]

     def to_pb(self) -> generate_pb2.Generation:
@@ -84,6 +85,7 @@ class Generation:
             token_id=self.token_id,
             token_logprob=self.token_logprob,
             token_text=self.token_text,
+            token_is_special=self.token_is_special,
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
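
For API consumers this flag is the point of the feature: generated text can be re-assembled client-side, skipping special tokens, without loading a tokenizer. A sketch against the new token shape (the EOS entry is hypothetical, for illustration only):

    # Rebuild the generated text while dropping special tokens:
    tokens = [
        {"id": 2978, "text": "test", "logprob": -1.5846587, "special": False},
        {"id": 1, "text": "</s>", "logprob": -0.1, "special": True},  # hypothetical
    ]
    text = "".join(t["text"] for t in tokens if not t["special"])
    assert text == "test"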