diff --git a/launcher/tests/bloom_560m.json b/launcher/tests/bloom_560m.json
index 17e2571e..96f89f6b 100644
--- a/launcher/tests/bloom_560m.json
+++ b/launcher/tests/bloom_560m.json
@@ -1,122 +1,142 @@
 {
+  "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
   "details": {
     "finish_reason": "length",
     "generated_tokens": 20,
+    "seed": null,
     "prefill": [
       {
         "id": 10264,
-        "logprob": null,
-        "text": "Test"
+        "text": "Test",
+        "logprob": null
       },
       {
         "id": 8821,
-        "logprob": -11.894989,
-        "text": " request"
+        "text": " request",
+        "logprob": -11.894989
       }
     ],
-    "seed": null,
     "tokens": [
       {
         "id": 17,
+        "text": ".",
         "logprob": -1.8267672,
-        "text": "."
+        "special": false
       },
       {
         "id": 1587,
+        "text": "get",
         "logprob": -2.4674969,
-        "text": "get"
+        "special": false
       },
       {
         "id": 11,
+        "text": "(",
         "logprob": -1.906001,
-        "text": "("
+        "special": false
       },
       {
         "id": 5,
+        "text": "\"",
         "logprob": -1.2279545,
-        "text": "\""
+        "special": false
       },
       {
         "id": 4899,
+        "text": "action",
         "logprob": -4.170299,
-        "text": "action"
+        "special": false
       },
       {
         "id": 5,
+        "text": "\"",
         "logprob": -0.32478866,
-        "text": "\""
+        "special": false
       },
       {
         "id": 12,
+        "text": ")",
         "logprob": -1.0773665,
-        "text": ")"
+        "special": false
       },
       {
         "id": 30,
+        "text": ";",
         "logprob": -0.27640742,
-        "text": ";"
+        "special": false
       },
       {
         "id": 837,
+        "text": "\n ",
         "logprob": -1.6970354,
-        "text": "\n "
+        "special": false
       },
       {
         "id": 1320,
+        "text": " if",
         "logprob": -1.4495516,
-        "text": " if"
+        "special": false
       },
       {
         "id": 375,
+        "text": " (",
         "logprob": -0.23609057,
-        "text": " ("
+        "special": false
       },
       {
         "id": 4899,
+        "text": "action",
         "logprob": -1.1916996,
-        "text": "action"
+        "special": false
       },
       {
         "id": 3535,
+        "text": " ==",
         "logprob": -0.8918753,
-        "text": " =="
+        "special": false
       },
       {
         "id": 5109,
+        "text": " null",
         "logprob": -0.3933342,
-        "text": " null"
+        "special": false
       },
       {
         "id": 12,
+        "text": ")",
         "logprob": -0.43212673,
-        "text": ")"
+        "special": false
       },
       {
         "id": 731,
+        "text": " {",
         "logprob": -0.17702064,
-        "text": " {"
+        "special": false
       },
       {
         "id": 1260,
+        "text": "\n ",
         "logprob": -0.07027565,
-        "text": "\n "
+        "special": false
       },
       {
         "id": 10519,
+        "text": " throw",
         "logprob": -1.3915029,
-        "text": " throw"
+        "special": false
       },
       {
         "id": 2084,
+        "text": " new",
         "logprob": -0.04201372,
-        "text": " new"
+        "special": false
       },
       {
         "id": 150858,
+        "text": " RuntimeException",
         "logprob": -1.7329919,
-        "text": " RuntimeException"
+        "special": false
       }
     ]
-  },
-  "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException"
+  }
 }
\ No newline at end of file
diff --git a/launcher/tests/integration_tests.rs b/launcher/tests/integration_tests.rs
index b70b1628..0d2b6c74 100644
--- a/launcher/tests/integration_tests.rs
+++ b/launcher/tests/integration_tests.rs
@@ -14,6 +14,7 @@ pub struct Token {
     id: u32,
     text: String,
     logprob: Option<f32>,
+    special: bool,
 }
 
 #[derive(Deserialize)]
@@ -136,6 +137,7 @@ fn compare_results(result: GeneratedText, expected: GeneratedText) {
     {
         assert_eq!(token.id, expected_token.id);
         assert_eq!(token.text, expected_token.text);
+        assert_eq!(token.special, expected_token.special);
         if let Some(logprob) = token.logprob {
             let expected_logprob = expected_token.logprob.unwrap();
             assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
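The fixtures above now carry a boolean "special" flag on every generated token (prefill tokens keep only id/text/logprob), and the Rust integration test compares it. A minimal sketch of how a fixture with this layout could be sanity-checked, assuming Python and the file path shown in the diff (this helper is not part of the change itself):

    import json

    # Layout taken from the diff above: details.prefill[*] has id/text/logprob,
    # details.tokens[*] additionally carries the new boolean "special" field.
    with open("launcher/tests/bloom_560m.json") as f:
        fixture = json.load(f)

    assert len(fixture["details"]["tokens"]) == fixture["details"]["generated_tokens"]
    for token in fixture["details"]["tokens"]:
        assert isinstance(token["special"], bool)  # new field on generated tokens
    for token in fixture["details"]["prefill"]:
        assert "special" not in token              # prefill tokens stay without it
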
diff --git a/launcher/tests/mt0_base.json b/launcher/tests/mt0_base.json
index cee3bc47..c06a2c26 100644
--- a/launcher/tests/mt0_base.json
+++ b/launcher/tests/mt0_base.json
@@ -1,117 +1,137 @@
 {
+  "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
   "details": {
     "finish_reason": "length",
     "generated_tokens": 20,
+    "seed": null,
     "prefill": [
       {
         "id": 0,
-        "logprob": null,
-        "text": ""
+        "text": "",
+        "logprob": null
       }
     ],
-    "seed": null,
     "tokens": [
       {
         "id": 259,
+        "text": "",
         "logprob": -1.3656927,
-        "text": ""
+        "special": false
       },
       {
         "id": 215100,
+        "text": "\"\"\"",
         "logprob": -2.6551573,
-        "text": "\"\"\""
+        "special": false
       },
       {
         "id": 46138,
+        "text": "Test",
         "logprob": -1.8059857,
-        "text": "Test"
+        "special": false
       },
       {
         "id": 287,
+        "text": "the",
         "logprob": -1.2102449,
-        "text": "the"
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -1.6057279,
-        "text": ""
+        "special": false
       },
       {
         "id": 49076,
+        "text": "contents",
         "logprob": -3.6060903,
-        "text": "contents"
+        "special": false
       },
       {
         "id": 304,
+        "text": "of",
         "logprob": -0.5270343,
-        "text": "of"
+        "special": false
       },
       {
         "id": 287,
+        "text": "the",
         "logprob": -0.62522805,
-        "text": "the"
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -1.4069618,
-        "text": ""
+        "special": false
       },
       {
         "id": 49076,
+        "text": "contents",
         "logprob": -2.621994,
-        "text": "contents"
+        "special": false
       },
       {
         "id": 304,
+        "text": "of",
         "logprob": -1.3172221,
-        "text": "of"
+        "special": false
       },
       {
         "id": 287,
+        "text": "the",
         "logprob": -0.3501925,
-        "text": "the"
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -0.7219573,
-        "text": ""
+        "special": false
       },
       {
         "id": 49076,
+        "text": "contents",
         "logprob": -1.0494149,
-        "text": "contents"
+        "special": false
       },
       {
         "id": 260,
+        "text": ".",
         "logprob": -1.0803378,
-        "text": "."
+        "special": false
       },
       {
         "id": 259,
+        "text": "",
         "logprob": -0.32933083,
-        "text": ""
+        "special": false
       },
       {
         "id": 215100,
+        "text": "\"\"\"",
         "logprob": -0.11268901,
-        "text": "\"\"\""
+        "special": false
       },
       {
         "id": 2978,
+        "text": "test",
         "logprob": -1.5846587,
-        "text": "test"
+        "special": false
       },
       {
         "id": 290,
+        "text": "_",
         "logprob": -0.49796978,
-        "text": "_"
+        "special": false
       },
       {
         "id": 4125,
+        "text": "test",
         "logprob": -2.0026445,
-        "text": "test"
+        "special": false
       }
     ]
-  },
-  "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test"
+  }
 }
\ No newline at end of file
\"\"\" test_test" + } } \ No newline at end of file diff --git a/proto/generate.proto b/proto/generate.proto index 0c4f9626..28a61362 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -108,8 +108,10 @@ message Generation { float token_logprob = 4; /// Text string token_text = 5; + /// Is it a special token + bool token_is_special = 6; /// Complete generated text - GeneratedText generated_text = 6; + GeneratedText generated_text = 7; } message PrefillRequest { diff --git a/router/src/infer.rs b/router/src/infer.rs index dc0df50a..598c2fcf 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -1,7 +1,7 @@ /// Batching and inference logic use crate::validation::{Validation, ValidationError}; -use crate::GenerateRequest; use crate::{Entry, Queue, Token}; +use crate::{GenerateRequest, PrefillToken}; use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_client::{ @@ -138,7 +138,7 @@ impl Infer { .into_iter() .zip(tokens.logprobs.into_iter()) .zip(tokens.texts.into_iter()) - .map(|((id, logprob), text)| Token { id, text, logprob }) + .map(|((id, logprob), text)| PrefillToken { id, text, logprob }) .collect(); } // Push last token @@ -372,6 +372,7 @@ fn send_generations(generations: Vec, entries: &mut IntMap, + pub(crate) prefill: Vec, pub(crate) tokens: Vec, pub(crate) generated_text: GeneratedText, pub(crate) queued: Instant, diff --git a/router/src/lib.rs b/router/src/lib.rs index 8e3199dd..1f23bfd3 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -86,6 +86,16 @@ pub(crate) struct GenerateRequest { pub parameters: GenerateParameters, } +#[derive(Debug, Serialize, ToSchema)] +pub struct PrefillToken { + #[schema(example = 0)] + id: u32, + #[schema(example = "test")] + text: String, + #[schema(nullable = true, example = -0.34)] + logprob: f32, +} + #[derive(Debug, Serialize, ToSchema)] pub struct Token { #[schema(example = 0)] @@ -94,6 +104,8 @@ pub struct Token { text: String, #[schema(nullable = true, example = -0.34)] logprob: f32, + #[schema(example = "false")] + special: bool, } #[derive(Serialize, ToSchema)] @@ -116,7 +128,7 @@ pub(crate) struct Details { pub generated_tokens: u32, #[schema(example = 42)] pub seed: Option, - pub prefill: Option>, + pub prefill: Option>, pub tokens: Option>, } diff --git a/router/src/server.rs b/router/src/server.rs index 6acbbffa..de96e397 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -2,7 +2,7 @@ use crate::infer::{InferError, InferStreamResponse}; use crate::{ Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, - Infer, StreamDetails, StreamResponse, Token, Validation, + Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -255,11 +255,11 @@ async fn generate_stream( let time_per_token = inference_time / generated_text.generated_tokens; // Tracing metadata - span.record("total_time", format!("{:?}", total_time)); - span.record("validation_time", format!("{:?}", validation_time)); - span.record("queue_time", format!("{:?}", queue_time)); - span.record("inference_time", format!("{:?}", inference_time)); - span.record("time_per_token", format!("{:?}", time_per_token)); + span.record("total_time", format!("{total_time:?}")); + span.record("validation_time", format!("{validation_time:?}")); + span.record("queue_time", format!("{queue_time:?}")); + span.record("inference_time", format!("{inference_time:?}")); + span.record("time_per_token", 
format!("{time_per_token:?}")); span.record("seed", format!("{:?}", generated_text.seed)); tracing::info!(parent: &span, "Output: {}", generated_text.text); @@ -349,6 +349,7 @@ pub async fn run( schemas( GenerateRequest, GenerateParameters, + PrefillToken, Token, GenerateResponse, Details, diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py index e109b83b..d15197d0 100644 --- a/server/text_generation/models/causal_lm.py +++ b/server/text_generation/models/causal_lm.py @@ -172,7 +172,9 @@ class CausalLMBatch(Batch): # and to remove unused allocated space left_offset = max_sequence_length - batch.max_sequence_length batch_left_offset = ( - batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset + batch.attention_mask.shape[1] + - batch.max_sequence_length + - batch.padding_right_offset ) attention_mask[ start_index:end_index, @@ -443,6 +445,7 @@ class CausalLM(Model): next_token_id_squeezed, next_token_logprob, next_token_text, + next_token_id_squeezed in self.all_special_ids, generated_text, ) diff --git a/server/text_generation/models/model.py b/server/text_generation/models/model.py index ef6a5682..52480526 100644 --- a/server/text_generation/models/model.py +++ b/server/text_generation/models/model.py @@ -12,6 +12,7 @@ B = TypeVar("B", bound=Batch) class Model(ABC): def __init__(self, tokenizer: PreTrainedTokenizerBase, device: torch.device): self.tokenizer = tokenizer + self.all_special_ids = set(tokenizer.all_special_ids) self.device = device @property diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py index 4813764b..3a4108ab 100644 --- a/server/text_generation/models/seq2seq_lm.py +++ b/server/text_generation/models/seq2seq_lm.py @@ -205,7 +205,8 @@ class Seq2SeqLMBatch(Batch): else: batch_left_offset = ( batch.decoder_attention_mask.shape[1] - - batch.max_decoder_input_length - batch.padding_right_offset + - batch.max_decoder_input_length + - batch.padding_right_offset ) decoder_attention_mask[ start_index:end_index, @@ -494,14 +495,10 @@ class Seq2SeqLM(Model): # Prefill if stopping_criteria.current_tokens == 1: - prefill_token_ids = decoder_input_ids[-new_decoder_input_length:-1] - prefill_texts = self.tokenizer.batch_decode( - prefill_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, - ) prefill_tokens = PrefillTokens( - prefill_token_ids, [float("nan")], prefill_texts + [self.tokenizer.bos_token_id], + [float("nan")], + [self.tokenizer.bos_token], ) else: prefill_tokens = None @@ -512,6 +509,7 @@ class Seq2SeqLM(Model): next_token_id_squeezed, next_token_logprob, next_token_text, + next_token_id_squeezed in self.all_special_ids, generated_text, ) diff --git a/server/text_generation/models/types.py b/server/text_generation/models/types.py index d1117b80..a3fbd6e8 100644 --- a/server/text_generation/models/types.py +++ b/server/text_generation/models/types.py @@ -73,6 +73,7 @@ class Generation: token_id: int token_logprob: float token_text: str + token_is_special: bool generated_text: Optional[GeneratedText] def to_pb(self) -> generate_pb2.Generation: @@ -84,6 +85,7 @@ class Generation: token_id=self.token_id, token_logprob=self.token_logprob, token_text=self.token_text, + token_is_special=self.token_is_special, generated_text=self.generated_text.to_pb() if self.generated_text is not None else None,