diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index c72d31d3..67afa04e 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -37,6 +37,7 @@ pub(crate) async fn generation_task( batch_size: Vec, sequence_length: u32, decode_length: u32, + top_n_tokens: Option, n_runs: usize, warmups: usize, parameters: NextTokenChooserParameters, @@ -48,7 +49,7 @@ pub(crate) async fn generation_task( // End task if a message is received on shutdown_receiver // _shutdown_guard_sender will be dropped once the task is finished tokio::select! { - res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => { + res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => { if let Err(err) = res { run_sender.send(Err(err)).await.unwrap_or(()); } @@ -64,6 +65,7 @@ async fn generate_runs( batch_size: Vec, sequence_length: u32, decode_length: u32, + top_n_tokens: Option, n_runs: usize, warmups: usize, parameters: NextTokenChooserParameters, @@ -82,6 +84,7 @@ async fn generate_runs( b, decode_length, parameters.clone(), + top_n_tokens, &mut client, ) .await?; @@ -97,6 +100,7 @@ async fn generate_runs( b, decode_length, parameters.clone(), + top_n_tokens, &mut client, ) .await?; @@ -130,6 +134,7 @@ async fn prefill( batch_size: u32, decode_length: u32, parameters: NextTokenChooserParameters, + top_n_tokens: Option, client: &mut ShardedClient, ) -> Result<(Prefill, CachedBatch), ClientError> { // Create requests @@ -145,6 +150,7 @@ async fn prefill( stop_sequences: vec![], ignore_eos_token: true, // Will not stop even if a eos token is generated }), + top_n_tokens: top_n_tokens.unwrap_or(0), }) .collect(); diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs index fcad400c..433c6f67 100644 --- a/benchmark/src/lib.rs +++ b/benchmark/src/lib.rs @@ -22,6 +22,7 @@ pub async fn run( batch_size: Vec, sequence_length: u32, decode_length: u32, + top_n_tokens: Option, n_runs: usize, warmups: usize, temperature: Option, @@ -70,6 +71,7 @@ pub async fn run( batch_size.clone(), sequence_length, decode_length, + top_n_tokens, n_runs, warmups, parameters, @@ -130,6 +132,7 @@ pub async fn run( tokenizer_name, sequence_length, decode_length, + top_n_tokens, n_runs, warmups, temperature, diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index a7550060..97c8af1c 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -93,6 +93,11 @@ struct Args { /// decoding strategies, for full doc refer to the `text-generation-server` #[clap(long, env)] do_sample: bool, + + /// Generation parameter in case you want to specifically test/debug particular + /// decoding strategies, for full doc refer to the `text-generation-server` + #[clap(long, env)] + top_n_tokens: Option, } fn main() -> Result<(), Box> { @@ -117,6 +122,7 @@ fn main() -> Result<(), Box> { watermark, do_sample, master_shard_uds_path, + top_n_tokens, } = args; let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]); @@ -173,6 +179,7 @@ fn main() -> Result<(), Box> { batch_size, sequence_length, decode_length, + top_n_tokens, runs, warmups, temperature, diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs index 6b74bc36..9e36717b 100644 --- a/benchmark/src/table.rs +++ b/benchmark/src/table.rs @@ -7,6 +7,7 @@ pub(crate) fn parameters_table( tokenizer_name: String, sequence_length: u32, decode_length: u32, + top_n_tokens: Option, n_runs: 
usize, warmups: usize, temperature: Option, @@ -24,6 +25,7 @@ pub(crate) fn parameters_table( builder.push_record(["Model", &tokenizer_name]); builder.push_record(["Sequence Length", &sequence_length.to_string()]); builder.push_record(["Decode Length", &decode_length.to_string()]); + builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]); builder.push_record(["N Runs", &n_runs.to_string()]); builder.push_record(["Warmups", &warmups.to_string()]); builder.push_record(["Temperature", &format!("{temperature:?}")]); diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index bf045d47..015613c2 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -75,6 +75,7 @@ class Client: typical_p: Optional[float] = None, watermark: bool = False, decoder_input_details: bool = False, + top_n_tokens: Optional[int] = None, ) -> Response: """ Given a prompt, generate the following text @@ -113,6 +114,8 @@ class Client: Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) decoder_input_details (`bool`): Return the decoder input token logprobs and ids + top_n_tokens (`int`): + Return the `n` most likely tokens at each step Returns: Response: generated response @@ -134,6 +137,7 @@ class Client: typical_p=typical_p, watermark=watermark, decoder_input_details=decoder_input_details, + top_n_tokens=top_n_tokens ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -164,6 +168,7 @@ class Client: truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + top_n_tokens: Optional[int] = None, ) -> Iterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens @@ -198,6 +203,8 @@ class Client: See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + top_n_tokens (`int`): + Return the `n` most likely tokens at each step Returns: Iterator[StreamResponse]: stream of generated tokens @@ -219,6 +226,7 @@ class Client: truncate=truncate, typical_p=typical_p, watermark=watermark, + top_n_tokens=top_n_tokens, ) request = Request(inputs=prompt, stream=True, parameters=parameters) @@ -317,6 +325,7 @@ class AsyncClient: typical_p: Optional[float] = None, watermark: bool = False, decoder_input_details: bool = False, + top_n_tokens: Optional[int] = None, ) -> Response: """ Given a prompt, generate the following text asynchronously @@ -355,6 +364,8 @@ class AsyncClient: Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) decoder_input_details (`bool`): Return the decoder input token logprobs and ids + top_n_tokens (`int`): + Return the `n` most likely tokens at each step Returns: Response: generated response @@ -376,6 +387,7 @@ class AsyncClient: truncate=truncate, typical_p=typical_p, watermark=watermark, + top_n_tokens=top_n_tokens, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -404,6 +416,7 @@ class AsyncClient: truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, + top_n_tokens: Optional[int] = None, ) -> AsyncIterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens asynchronously @@ -438,6 +451,8 @@ class AsyncClient: See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more 
information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + top_n_tokens (`int`): + Return the `n` most likely tokens at each step Returns: AsyncIterator[StreamResponse]: stream of generated tokens @@ -459,6 +474,7 @@ class AsyncClient: truncate=truncate, typical_p=typical_p, watermark=watermark, + top_n_tokens=top_n_tokens, ) request = Request(inputs=prompt, stream=True, parameters=parameters) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 548f0b63..38f75253 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -39,6 +39,8 @@ class Parameters(BaseModel): details: bool = False # Get decoder input token logprobs and ids decoder_input_details: bool = False + # Return the N most likely tokens at each step + top_n_tokens: Optional[int] @validator("best_of") def valid_best_of(cls, field_value, values): @@ -101,6 +103,12 @@ class Parameters(BaseModel): raise ValidationError("`typical_p` must be > 0.0 and < 1.0") return v + @validator("top_n_tokens") + def valid_top_n_tokens(cls, v): + if v is not None and v <= 0: + raise ValidationError("`top_n_tokens` must be strictly positive") + return v + class Request(BaseModel): # Prompt @@ -125,9 +133,7 @@ class Request(BaseModel): and parameters.best_of > 1 and field_value ): - raise ValidationError( - "`best_of` != 1 is not supported when `stream` == True" - ) + raise ValidationError("`best_of` != 1 is not supported when `stream` == True") return field_value @@ -179,6 +185,8 @@ class BestOfSequence(BaseModel): prefill: List[InputToken] # Generated tokens tokens: List[Token] + # Most likely tokens + top_tokens: Optional[List[List[Token]]] # `generate` details @@ -193,6 +201,8 @@ class Details(BaseModel): prefill: List[InputToken] # Generated tokens tokens: List[Token] + # Most likely tokens + top_tokens: Optional[List[List[Token]]] # Additional sequences when using the `best_of` parameter best_of_sequences: Optional[List[BestOfSequence]] @@ -219,6 +229,8 @@ class StreamDetails(BaseModel): class StreamResponse(BaseModel): # Generated token token: Token + # Most likely tokens + top_tokens: Optional[List[Token]] # Complete generated text # Only available when the generation is finished generated_text: Optional[str] diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 75762712..cbb6f25d 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -159,6 +159,14 @@ struct Args { #[clap(default_value = "4", long, env)] max_stop_sequences: usize, + /// This is the maximum allowed value for clients to set `top_n_tokens`. + /// `top_n_tokens is used to return information about the the `n` most likely + /// tokens at each generation step, instead of just the sampled token. This + /// information can be used for downstream tasks like for classification or + /// ranking. + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, + /// This is the maximum allowed input length (expressed in number of tokens) /// for users. The larger this value, the longer prompt users can send which /// can impact the overall memory required to handle the load. 
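For reference, the client.py and types.py hunks above add a `top_n_tokens` request parameter and a `top_tokens` field on the response details, while the launcher hunk here caps it with `--max-top-n-tokens`. The snippet below is a minimal sketch of how the pieces fit together; the endpoint URL, the prompt, and the value 3 are illustrative assumptions rather than values taken from this diff, and the request is only accepted if it stays within the router's `--max-top-n-tokens` limit.

# Sketch only: URL, prompt, and top_n_tokens=3 are placeholders.
from text_generation import Client

client = Client("http://127.0.0.1:8080")

response = client.generate(
    "The capital of France is",
    max_new_tokens=5,
    top_n_tokens=3,  # must not exceed --max-top-n-tokens (default 5)
)

# Each generation step now carries its n most likely candidate tokens.
if response.details.top_tokens:
    for step in response.details.top_tokens:
        print([(token.text, token.logprob) for token in step])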
@@ -929,6 +937,8 @@ fn spawn_webserver( args.max_best_of.to_string(), "--max-stop-sequences".to_string(), args.max_stop_sequences.to_string(), + "--max-top-n-tokens".to_string(), + args.max_top_n_tokens.to_string(), "--max-input-length".to_string(), args.max_input_length.to_string(), "--max-total-tokens".to_string(), diff --git a/proto/generate.proto b/proto/generate.proto index 57d79bca..3f607dc5 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -91,6 +91,8 @@ message Request { StoppingCriteriaParameters stopping_parameters = 5; /// Return prefill logprobs bool prefill_logprobs = 6; + /// Return most likely n tokens + uint32 top_n_tokens = 7; } message Batch { @@ -141,6 +143,17 @@ message PrefillTokens { repeated string texts = 3; } +message TopTokens { + /// Top Token IDs + repeated uint32 ids = 1; + /// Top Logprobs + repeated float logprobs = 2; + /// Top Token Texts + repeated string texts = 3; + /// If the tokens are special + repeated bool is_special = 6; +} + message Generation { /// Request ID uint64 request_id = 1; @@ -156,6 +169,8 @@ message Generation { bool token_is_special = 6; /// Complete generated text optional GeneratedText generated_text = 7; + /// Top tokens + TopTokens top_tokens = 8; } message FilterBatchRequest { diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 7753f307..d427d3a4 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -131,6 +131,7 @@ impl Client { ignore_eos_token: false, }), prefill_logprobs: true, + top_n_tokens: 20, }); n_tokens += max_input_length; } diff --git a/router/src/health.rs b/router/src/health.rs index a3cacdcd..ab290fc1 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -50,6 +50,7 @@ impl Health { stop_sequences: vec![], ignore_eos_token: false, }), + top_n_tokens: 0, }; let batch = Batch { id: BATCH_ID, diff --git a/router/src/infer.rs b/router/src/infer.rs index 188ddc64..67b5bde2 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -138,12 +138,15 @@ impl Infer { &self, request: GenerateRequest, ) -> Result { + let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); + // Create stream and keep semaphore permit as long as generate lives let (_permit, mut stream) = self.generate_stream(request).await?; // Return values let mut result_prefill = Vec::new(); let mut result_tokens = Vec::new(); + let mut result_top_tokens = Vec::new(); let mut result_generated_text = None; let mut result_start = None; let mut result_queued = None; @@ -164,7 +167,10 @@ impl Infer { .collect(); } // Push last token - InferStreamResponse::Token(token) => result_tokens.push(token), + InferStreamResponse::Intermediate { token, top_tokens } => { + result_tokens.push(token); + result_top_tokens.push(top_tokens); + } // Final message // Set return values InferStreamResponse::End { @@ -172,8 +178,10 @@ impl Infer { generated_text, start, queued, + top_tokens, } => { result_tokens.push(token); + result_top_tokens.push(top_tokens); result_generated_text = Some(generated_text); result_start = Some(start); result_queued = Some(queued) @@ -191,6 +199,11 @@ impl Infer { generated_text, queued, start, + top_tokens: if use_top_tokens { + result_top_tokens + } else { + Vec::new() + }, }) } else { let err = InferError::IncompleteGeneration; @@ -520,6 +533,26 @@ fn send_responses( special: generation.token_is_special, }; + // generation.top_tokens + + let mut top_tokens = Vec::new(); + if let Some(top_tokens_) = generation.top_tokens { + top_tokens.extend( + 
top_tokens_ + .ids + .into_iter() + .zip(top_tokens_.logprobs.into_iter()) + .zip(top_tokens_.texts.into_iter()) + .zip(top_tokens_.is_special.into_iter()) + .map(|(((id, logprob), text), special)| Token { + id, + text, + logprob, + special, + }), + ) + } + if let Some(generated_text) = generation.generated_text { // Generation has ended stopped = true; @@ -527,6 +560,7 @@ fn send_responses( entry.response_tx.send_timeout( Ok(InferStreamResponse::End { token, + top_tokens, generated_text, queued: entry.queue_time, start: entry.batch_time.unwrap(), @@ -536,7 +570,7 @@ fn send_responses( } else { // Send message entry.response_tx.send_timeout( - Ok(InferStreamResponse::Token(token)), + Ok(InferStreamResponse::Intermediate { token, top_tokens }), Duration::from_millis(10), )?; } @@ -566,10 +600,14 @@ pub(crate) enum InferStreamResponse { // Optional first message Prefill(PrefillTokens), // Intermediate messages - Token(Token), + Intermediate { + token: Token, + top_tokens: Vec, + }, // Last message End { token: Token, + top_tokens: Vec, generated_text: GeneratedText, start: Instant, queued: Instant, @@ -583,6 +621,7 @@ pub(crate) struct InferResponse { pub(crate) generated_text: GeneratedText, pub(crate) queued: Instant, pub(crate) start: Instant, + pub(crate) top_tokens: Vec>, } #[derive(Debug, Error)] diff --git a/router/src/lib.rs b/router/src/lib.rs index 7dff7a11..76e70bb7 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -135,6 +135,9 @@ pub(crate) struct GenerateParameters { example = "null" )] pub seed: Option, + #[serde(default)] + #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] + pub top_n_tokens: Option, } fn default_max_new_tokens() -> u32 { @@ -158,6 +161,7 @@ fn default_parameters() -> GenerateParameters { details: false, decoder_input_details: false, seed: None, + top_n_tokens: None, } } @@ -235,6 +239,8 @@ pub(crate) struct BestOfSequence { pub seed: Option, pub prefill: Vec, pub tokens: Vec, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub top_tokens: Vec>, } #[derive(Serialize, ToSchema)] @@ -249,6 +255,8 @@ pub(crate) struct Details { pub tokens: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub best_of_sequences: Option>, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub top_tokens: Vec>, } #[derive(Serialize, ToSchema)] @@ -272,6 +280,8 @@ pub(crate) struct StreamDetails { #[derive(Serialize, ToSchema)] pub(crate) struct StreamResponse { pub token: Token, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub top_tokens: Vec, #[schema(nullable = true, default = "null", example = "test")] pub generated_text: Option, #[schema(nullable = true, default = "null")] diff --git a/router/src/main.rs b/router/src/main.rs index 484643cb..4903c066 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -29,6 +29,8 @@ struct Args { max_best_of: usize, #[clap(default_value = "4", long, env)] max_stop_sequences: usize, + #[clap(default_value = "5", long, env)] + max_top_n_tokens: u32, #[clap(default_value = "1024", long, env)] max_input_length: usize, #[clap(default_value = "2048", long, env)] @@ -75,6 +77,7 @@ fn main() -> Result<(), RouterError> { max_concurrent_requests, max_best_of, max_stop_sequences, + max_top_n_tokens, max_input_length, max_total_tokens, waiting_served_ratio, @@ -259,6 +262,7 @@ fn main() -> Result<(), RouterError> { max_concurrent_requests, max_best_of, max_stop_sequences, + max_top_n_tokens, max_input_length, max_total_tokens, waiting_served_ratio, diff --git a/router/src/queue.rs 
b/router/src/queue.rs index 2d8d6d1c..e97a168e 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -235,6 +235,7 @@ impl State { truncate: entry.request.truncate, parameters: Some(entry.request.parameters.clone()), stopping_parameters: Some(entry.request.stopping_parameters.clone()), + top_n_tokens: entry.request.top_n_tokens, }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -328,6 +329,7 @@ mod tests { max_new_tokens: 1, stop_sequences: vec![], }, + top_n_tokens: 0, }, response_tx, span: info_span!("entry"), diff --git a/router/src/server.rs b/router/src/server.rs index e609821c..91164098 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -158,7 +158,7 @@ async fn generate( add_prompt = Some(req.inputs.clone()); } - let details = req.parameters.details || req.parameters.decoder_input_details; + let details: bool = req.parameters.details || req.parameters.decoder_input_details; // Inference let (response, best_of_responses) = match req.parameters.best_of { @@ -191,6 +191,7 @@ async fn generate( generated_tokens: response.generated_text.generated_tokens, prefill: response.prefill, tokens: response.tokens, + top_tokens: response.top_tokens, seed: response.generated_text.seed, } }) @@ -204,6 +205,7 @@ async fn generate( tokens: response.tokens, seed: response.generated_text.seed, best_of_sequences, + top_tokens: response.top_tokens, }) } false => None, @@ -385,12 +387,16 @@ async fn generate_stream( // Prefill is ignored InferStreamResponse::Prefill(_) => {} // Yield event for every new token - InferStreamResponse::Token(token) => { + InferStreamResponse::Intermediate{ + token, + top_tokens, + } => { tracing::debug!(parent: &span, "Token: {:?}", token); // StreamResponse let stream_token = StreamResponse { token, + top_tokens: top_tokens, generated_text: None, details: None, }; @@ -403,6 +409,7 @@ async fn generate_stream( generated_text, start, queued, + top_tokens, } => { // Token details let details = match details { @@ -451,6 +458,7 @@ async fn generate_stream( let stream_token = StreamResponse { token, + top_tokens: top_tokens, generated_text: Some(output_text), details }; @@ -509,6 +517,7 @@ pub async fn run( max_concurrent_requests: usize, max_best_of: usize, max_stop_sequences: usize, + max_top_n_tokens: u32, max_input_length: usize, max_total_tokens: usize, waiting_served_ratio: f32, @@ -571,6 +580,7 @@ pub async fn run( tokenizer, max_best_of, max_stop_sequences, + max_top_n_tokens, max_input_length, max_total_tokens, ); diff --git a/router/src/validation.rs b/router/src/validation.rs index f967361f..6c67f0ff 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -15,6 +15,7 @@ pub struct Validation { /// Validation parameters max_best_of: usize, max_stop_sequences: usize, + max_top_n_tokens: u32, max_input_length: usize, max_total_tokens: usize, /// Channel to communicate with the background tokenization task @@ -27,6 +28,7 @@ impl Validation { tokenizer: Option, max_best_of: usize, max_stop_sequences: usize, + max_top_n_tokens: u32, max_input_length: usize, max_total_tokens: usize, ) -> Self { @@ -54,6 +56,7 @@ impl Validation { max_best_of, sender, max_stop_sequences, + max_top_n_tokens, max_input_length, max_total_tokens, } @@ -142,6 +145,7 @@ impl Validation { seed, watermark, decoder_input_details, + top_n_tokens, .. 
} = request.parameters; @@ -218,6 +222,15 @@ impl Validation { } }; + let top_n_tokens = top_n_tokens + .map(|value| { + if value > self.max_top_n_tokens { + return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value)); + } + Ok(value) + }) + .unwrap_or(Ok(0))?; + // Check if inputs is empty if request.inputs.is_empty() { return Err(EmptyInput); @@ -263,6 +276,7 @@ impl Validation { truncate: truncate.unwrap_or(self.max_input_length) as u32, parameters, stopping_parameters, + top_n_tokens: top_n_tokens, }) } @@ -336,6 +350,7 @@ pub(crate) struct ValidGenerateRequest { pub decoder_input_details: bool, pub parameters: NextTokenChooserParameters, pub stopping_parameters: StoppingCriteriaParameters, + pub top_n_tokens: u32, } #[derive(Error, Debug)] @@ -350,6 +365,10 @@ pub enum ValidationError { BestOfSeed, #[error("`best_of` != 1 is not supported when streaming tokens")] BestOfStream, + #[error("`top_n_tokens` must be >= 0 and <= {0}. Given: {1}")] + TopNTokens(u32, u32), + #[error("`top_n_tokens` != 0 is not allowed for this endpoint")] + TopNTokensDisabled, #[error("`decoder_input_details` == true is not supported when streaming tokens")] PrefillDetailsStream, #[error("`temperature` must be strictly positive")] @@ -391,14 +410,16 @@ mod tests { let tokenizer = None; let max_best_of = 2; let max_stop_sequence = 3; - let max_input_length = 4; - let max_total_tokens = 5; + let max_top_n_tokens = 4; + let max_input_length = 5; + let max_total_tokens = 6; let workers = 1; let validation = Validation::new( workers, tokenizer, max_best_of, max_stop_sequence, + max_top_n_tokens, max_input_length, max_total_tokens, ); @@ -418,14 +439,16 @@ mod tests { let tokenizer = Some(get_tokenizer().await); let max_best_of = 2; let max_stop_sequence = 3; - let max_input_length = 4; - let max_total_tokens = 5; + let max_top_n_tokens = 4; + let max_input_length = 5; + let max_total_tokens = 6; let workers = 1; let validation = Validation::new( workers, tokenizer, max_best_of, max_stop_sequence, + max_top_n_tokens, max_input_length, max_total_tokens, ); @@ -435,7 +458,7 @@ mod tests { .validate_input("Hello".to_string(), None, max_new_tokens) .await { - Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (), + Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), _ => panic!("Unexpected not max new tokens"), } } @@ -445,14 +468,16 @@ mod tests { let tokenizer = Some(get_tokenizer().await); let max_best_of = 2; let max_stop_sequence = 3; - let max_input_length = 4; - let max_total_tokens = 5; + let max_top_n_tokens = 4; + let max_input_length = 5; + let max_total_tokens = 6; let workers = 1; let validation = Validation::new( workers, tokenizer, max_best_of, max_stop_sequence, + max_top_n_tokens, max_input_length, max_total_tokens, ); @@ -477,14 +502,16 @@ mod tests { let tokenizer = Some(get_tokenizer().await); let max_best_of = 2; let max_stop_sequence = 3; - let max_input_length = 4; - let max_total_tokens = 5; + let max_top_n_tokens = 4; + let max_input_length = 5; + let max_total_tokens = 6; let workers = 1; let validation = Validation::new( workers, tokenizer, max_best_of, max_stop_sequence, + max_top_n_tokens, max_input_length, max_total_tokens, ); @@ -531,4 +558,75 @@ mod tests { // top_p == 1.0 is invalid for users to ask for but it's the default resolved value. 
assert_eq!(valid_request.parameters.top_p, 1.0); } + + #[tokio::test] + async fn test_validation_top_n_tokens() { + let tokenizer = Some(get_tokenizer().await); + let max_best_of = 2; + let max_stop_sequences = 3; + let max_top_n_tokens = 4; + let max_input_length = 5; + let max_total_tokens = 6; + let workers = 1; + let validation = Validation::new( + workers, + tokenizer, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_length, + max_total_tokens, + ); + match validation + .validate(GenerateRequest { + inputs: "Hello".to_string(), + parameters: GenerateParameters { + top_n_tokens: Some(5), + ..default_parameters() + }, + }) + .await + { + Err(ValidationError::TopNTokens(4, 5)) => (), + _ => panic!("Unexpected top_n_tokens"), + } + + validation + .validate(GenerateRequest { + inputs: "Hello".to_string(), + parameters: GenerateParameters { + top_n_tokens: Some(4), + max_new_tokens: 1, + ..default_parameters() + }, + }) + .await + .unwrap(); + + validation + .validate(GenerateRequest { + inputs: "Hello".to_string(), + parameters: GenerateParameters { + top_n_tokens: Some(0), + max_new_tokens: 1, + ..default_parameters() + }, + }) + .await + .unwrap(); + + let valid_request = validation + .validate(GenerateRequest { + inputs: "Hello".to_string(), + parameters: GenerateParameters { + top_n_tokens: None, + max_new_tokens: 1, + ..default_parameters() + }, + }) + .await + .unwrap(); + + assert_eq!(valid_request.top_n_tokens, 0); + } } diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py index da0006e4..4187ff25 100644 --- a/server/tests/utils/test_tokens.py +++ b/server/tests/utils/test_tokens.py @@ -1,7 +1,9 @@ +import torch from text_generation_server.utils.tokens import ( StopSequenceCriteria, StoppingCriteria, FinishReason, + batch_top_tokens, ) @@ -42,3 +44,22 @@ def test_stopping_criteria_max(): assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None) assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH) + +def test_batch_top_tokens(): + top_n_tokens = [0, 2, 3, 4, 5] + top_n_tokens_tensor = torch.tensor(top_n_tokens) + inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5) + + topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs) + + assert topn_tok_ids[0] == [] + assert topn_tok_ids[1] == [0, 3] + assert topn_tok_ids[2] == [0, 3, 1, 4] + assert topn_tok_ids[3] == [0, 3, 1, 4] + assert topn_tok_ids[4] == [0, 3, 1, 4, 2] + + assert topn_tok_logprobs[0] == [] + assert topn_tok_logprobs[1] == [-1, -2] + assert topn_tok_logprobs[2] == [-1, -2, -3, -3] + assert topn_tok_logprobs[3] == [-1, -2, -3, -3] + assert topn_tok_logprobs[4] == [-1, -2, -3, -3, -4] diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index cbdf4808..4e338263 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -1,3 +1,4 @@ +from text_generation_server.utils.tokens import batch_top_tokens import torch import inspect @@ -12,6 +13,7 @@ from text_generation_server.models.types import ( PrefillTokens, Generation, GeneratedText, + TopTokens, ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -42,6 +44,8 @@ class CausalLMBatch(Batch): # Generation helpers next_token_choosers: List[NextTokenChooser] stopping_criterias: List[StoppingCriteria] + top_n_tokens: List[int] + 
top_n_tokens_tensor: torch.Tensor # Metadata used for padding max_input_length: int @@ -72,6 +76,7 @@ class CausalLMBatch(Batch): inputs = [] next_token_choosers = [] stopping_criterias = [] + top_n_tokens = [] prefix_offsets = [] read_offsets = [] requests_idx_mapping = {} @@ -88,6 +93,7 @@ class CausalLMBatch(Batch): r.stopping_parameters, tokenizer ) stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) max_truncation = max(max_truncation, r.truncate) max_decode_tokens += stopping_criteria.max_new_tokens padding_right_offset = max( @@ -121,6 +127,9 @@ class CausalLMBatch(Batch): position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) max_tokens = len(inputs) * (max_input_length + max_decode_tokens) @@ -138,6 +147,8 @@ class CausalLMBatch(Batch): read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, max_input_length=max_input_length.item(), padding_right_offset=padding_right_offset, max_tokens=max_tokens, @@ -163,6 +174,7 @@ class CausalLMBatch(Batch): next_token_choosers = [] stopping_criterias = [] + top_n_tokens = [] total_remaining_decode_tokens = 0 new_padding_right_offset = 0 @@ -184,6 +196,7 @@ class CausalLMBatch(Batch): next_token_choosers.append(self.next_token_choosers[idx]) stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) + top_n_tokens.append(self.top_n_tokens[idx]) remaining_decode_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) @@ -223,6 +236,7 @@ class CausalLMBatch(Batch): layer[1] = past_values[keep_indices, :, -past_kv_length:, :] del past_values + top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens self.requests = requests @@ -235,6 +249,8 @@ class CausalLMBatch(Batch): self.read_offsets = read_offsets self.next_token_choosers = next_token_choosers self.stopping_criterias = stopping_criterias + self.top_n_tokens = top_n_tokens + self.top_n_tokens_tensor = top_n_tokens_tensor self.max_input_length = max_input_length self.padding_right_offset = new_padding_right_offset self.max_tokens = max_tokens @@ -262,6 +278,7 @@ class CausalLMBatch(Batch): all_input_ids = [] next_token_choosers = [] stopping_criterias = [] + top_n_tokens = [] max_tokens = 0 # Batch tensors @@ -269,6 +286,7 @@ class CausalLMBatch(Batch): attention_mask = None position_ids = None past_key_values = [] + top_n_tokens_tensor = None # Used for slicing correctly inside the tensors # Equivalent to a cumsum on batch sizes @@ -281,6 +299,7 @@ class CausalLMBatch(Batch): all_input_ids.extend(batch.all_input_ids) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) + top_n_tokens.extend(batch.top_n_tokens) if i == 0: requests_idx_mapping = batch.requests_idx_mapping @@ -310,6 +329,12 @@ class CausalLMBatch(Batch): (total_batch_size, max_input_length + padding_right_offset), ) + if top_n_tokens_tensor is None: + top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( + total_batch_size, + ) + top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor + # We need to slice the attention 
mask to remove padding from previous steps # and to remove unused allocated space left_offset = max_input_length - batch.max_input_length @@ -438,6 +463,8 @@ class CausalLMBatch(Batch): read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, max_input_length=max_input_length, padding_right_offset=padding_right_offset, keys_head_dim_last=batches[0].keys_head_dim_last, @@ -549,6 +576,12 @@ class CausalLM(Model): generations: List[Generation] = [] stopped = True + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + torch.softmax(logits[:, -1], -1), + ) + # Zipped iterator iterator = zip( batch.requests, @@ -559,6 +592,9 @@ class CausalLM(Model): batch.next_token_choosers, batch.stopping_criterias, batch.all_input_ids, + batch.top_n_tokens, + batch_top_token_ids, + batch_top_token_logprobs, ) # For each member of the batch @@ -571,6 +607,9 @@ class CausalLM(Model): next_token_chooser, stopping_criteria, all_input_ids, + top_n_tokens, + top_token_ids, + top_token_logprobs, ) in enumerate(iterator): # Select next token next_token_id, logprobs = next_token_chooser( @@ -637,6 +676,24 @@ class CausalLM(Model): else: prefill_tokens = None + if top_n_tokens > 0: + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids for token_id in top_token_ids + ] + top_tokens = TopTokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + else: + top_tokens = None + generation = Generation( request.id, prefill_tokens, @@ -645,6 +702,7 @@ class CausalLM(Model): next_token_text, next_token_id_squeezed.item() in self.all_special_ids, generated_text, + top_tokens, ) generations.append(generation) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 7de51358..d6af07f4 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1,5 +1,6 @@ import math import itertools +from text_generation_server.utils.tokens import batch_top_tokens import torch import torch.distributed @@ -16,6 +17,7 @@ from text_generation_server.models.types import ( PrefillTokens, Generation, GeneratedText, + TopTokens, ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser @@ -165,6 +167,8 @@ class FlashCausalLMBatch(Batch): # Generation helpers next_token_chooser: HeterogeneousNextTokenChooser stopping_criterias: List[StoppingCriteria] + top_n_tokens: List[int] + top_n_tokens_tensor: torch.Tensor # Number of blocks in this batch blocks: int @@ -217,6 +221,7 @@ class FlashCausalLMBatch(Batch): next_token_chooser_parameters = [] stopping_criterias = [] + top_n_tokens = [] # Cumulative length cumulative_length = 0 @@ -259,6 +264,7 @@ class FlashCausalLMBatch(Batch): ) max_new_tokens = stopping_criteria.max_new_tokens stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) # Paged attention # Remove one as the first token des not have a past @@ -352,6 +358,9 @@ class FlashCausalLMBatch(Batch): prefill_next_token_indices = torch.tensor( prefill_next_token_indices, dtype=torch.int64, device=device ) + top_n_tokens_tensor = 
torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) return cls( batch_id=pb.id, @@ -378,6 +387,8 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, ) @@ -417,6 +428,7 @@ class FlashCausalLMBatch(Batch): read_offsets = [] stopping_criterias = [] + top_n_tokens = [] blocks = 0 max_blocks = 0 @@ -443,6 +455,8 @@ class FlashCausalLMBatch(Batch): stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) + top_n_tokens.append(self.top_n_tokens[idx]) + remaining_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) @@ -487,6 +501,7 @@ class FlashCausalLMBatch(Batch): input_lengths_tensor = self.input_lengths_tensor[indices] slots = self.slots[slot_filtering_indices] next_token_chooser = self.next_token_chooser.filter(indices) + top_n_tokens_tensor = self.top_n_tokens_tensor[indices] start_slots = torch.tensor(start_slots, dtype=torch.int64) @@ -518,6 +533,8 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, ) @@ -566,6 +583,9 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( (total_batch_size, max_length) ) + top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( + total_batch_size, + ) start_slots = [] block_tables = [] @@ -577,6 +597,7 @@ class FlashCausalLMBatch(Batch): next_token_chooser_parameters = [] stopping_criterias = [] + top_n_tokens = [] # Cumulative length cumulative_batch_size = 0 @@ -602,6 +623,7 @@ class FlashCausalLMBatch(Batch): position_ids[start_index:end_index] = batch.position_ids slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor + top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor slots[slots_start_index:slots_end_index] = batch.slots all_input_ids_tensor[ @@ -624,6 +646,8 @@ class FlashCausalLMBatch(Batch): next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) stopping_criterias.extend(batch.stopping_criterias) + top_n_tokens.extend(batch.top_n_tokens) + # Update cumulative_batch_size += len(batch) cumulative_slots += len(batch.slots) @@ -666,6 +690,8 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, blocks=blocks, max_blocks=max_blocks, ) @@ -831,10 +857,14 @@ class FlashCausalLM(Model): else: next_token_logits = out - next_input_ids, next_token_logprobs = batch.next_token_chooser( + next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser( batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits ) + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs + ) + if prefill: if len(batch) > 1 and prefill_logprobs: # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs @@ -931,8 +961,11 @@ class FlashCausalLM(Model): batch.all_input_ids, 
batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, + batch.top_n_tokens, next_token_ids, next_token_logprobs, + batch_top_token_ids, + batch_top_token_logprobs, ) # For each member of the batch @@ -945,8 +978,11 @@ class FlashCausalLM(Model): all_input_ids, do_sample, seed, + top_n_tokens, next_token_id, next_token_logprob, + top_token_ids, + top_token_logprobs, ) in enumerate(iterator): # Append next token to all tokens all_input_ids.append(next_token_id) @@ -1005,6 +1041,24 @@ class FlashCausalLM(Model): else: prefill_tokens = None + if top_n_tokens > 0: + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids for token_id in top_token_ids + ] + top_tokens = TopTokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + else: + top_tokens = None + generation = Generation( request.id, prefill_tokens, @@ -1013,6 +1067,7 @@ class FlashCausalLM(Model): next_token_text, next_token_id in self.all_special_ids, generated_text, + top_tokens, ) generations.append(generation) diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index faad63ba..2dac87bc 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -763,6 +763,8 @@ class IdeficsCausalLM(Model): else: prefill_tokens = None + top_tokens=None + generation = Generation( request.id, prefill_tokens, @@ -771,6 +773,7 @@ class IdeficsCausalLM(Model): next_token_text, next_token_id_squeezed.item() in self.all_special_ids, generated_text, + top_tokens ) generations.append(generation) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 9e5c21d1..361453fb 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -1,3 +1,4 @@ +from text_generation_server.utils.tokens import batch_top_tokens import torch from dataclasses import dataclass @@ -11,6 +12,7 @@ from text_generation_server.models.types import ( Batch, Generation, PrefillTokens, + TopTokens, ) from text_generation_server.pb import generate_pb2 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling @@ -48,6 +50,8 @@ class Seq2SeqLMBatch(Batch): # Generation helpers next_token_choosers: List[NextTokenChooser] stopping_criterias: List[StoppingCriteria] + top_n_tokens: List[int] + top_n_tokens_tensor: torch.Tensor # Metadata used for padding max_input_length: int @@ -78,7 +82,7 @@ class Seq2SeqLMBatch(Batch): inputs = [] next_token_choosers = [] stopping_criterias = [] - + top_n_tokens = [] decoder_input_lengths = [] prefix_offsets = [] read_offsets = [] @@ -97,6 +101,7 @@ class Seq2SeqLMBatch(Batch): r.stopping_parameters, tokenizer ) stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) max_truncation = max(max_truncation, r.truncate) max_decode_tokens += stopping_criteria.max_new_tokens padding_right_offset = max( @@ -126,6 +131,9 @@ class Seq2SeqLMBatch(Batch): prefix_offsets.append(0) read_offsets.append(1) all_decoder_input_ids = decoder_input_ids.view(-1).split(1) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) max_tokens = len(inputs) * (max_input_length + max_decode_tokens) @@ -146,6 +154,8 @@ class 
Seq2SeqLMBatch(Batch): read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, max_input_length=max_input_length.item(), max_decoder_input_length=1, padding_right_offset=padding_right_offset, @@ -173,6 +183,7 @@ class Seq2SeqLMBatch(Batch): next_token_choosers = [] stopping_criterias = [] + top_n_tokens = [] max_input_length = 0 max_decoder_input_length = 0 @@ -204,6 +215,7 @@ class Seq2SeqLMBatch(Batch): next_token_choosers.append(self.next_token_choosers[idx]) stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) + top_n_tokens.append(self.top_n_tokens[idx]) remaining_decode_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) @@ -239,6 +251,7 @@ class Seq2SeqLMBatch(Batch): layer[2] = layer[2][keep_indices, :, -max_input_length:] layer[3] = layer[3][keep_indices, :, -max_input_length:] + top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] max_tokens = ( len(request_ids) * (max_input_length + max_decoder_input_length) + remaining_decode_tokens @@ -254,6 +267,8 @@ class Seq2SeqLMBatch(Batch): self.read_offsets = read_offsets self.next_token_choosers = next_token_choosers self.stopping_criterias = stopping_criterias + self.top_n_tokens = top_n_tokens + self.top_n_tokens_tensor = top_n_tokens_tensor self.max_input_length = max_input_length self.max_decoder_input_length = max_decoder_input_length self.padding_right_offset = padding_right_offset @@ -289,6 +304,7 @@ class Seq2SeqLMBatch(Batch): read_offsets = [] next_token_choosers = [] stopping_criterias = [] + top_n_tokens = [] max_tokens = 0 # Batch tensors @@ -296,6 +312,7 @@ class Seq2SeqLMBatch(Batch): decoder_input_ids = None decoder_attention_mask = None encoder_last_hidden_state = None + top_n_tokens_tensor = None past_key_values = [] # Used for slicing correctly inside the tensors @@ -312,6 +329,7 @@ class Seq2SeqLMBatch(Batch): read_offsets.extend(batch.read_offsets) next_token_choosers.extend(batch.next_token_choosers) stopping_criterias.extend(batch.stopping_criterias) + top_n_tokens.extend(batch.top_n_tokens) if i == 0: requests_idx_mapping = batch.requests_idx_mapping @@ -384,6 +402,12 @@ class Seq2SeqLMBatch(Batch): ), ) + if top_n_tokens_tensor is None: + top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( + total_batch_size, + ) + top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor + # Copy to correct indices encoder_last_hidden_state[ start_index:end_index, -batch.max_input_length :, : @@ -488,6 +512,8 @@ class Seq2SeqLMBatch(Batch): read_offsets=read_offsets, next_token_choosers=next_token_choosers, stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, max_input_length=max_input_length, max_decoder_input_length=max_decoder_input_length, padding_right_offset=padding_right_offset, @@ -613,6 +639,12 @@ class Seq2SeqLM(Model): batch.past_key_values, ) + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + torch.softmax(logits[:, -1], -1), + ) + # Finished requests generations: List[Generation] = [] stopped = True @@ -628,6 +660,9 @@ class Seq2SeqLM(Model): batch.next_token_choosers, batch.stopping_criterias, batch.all_decoder_input_ids, + batch.top_n_tokens, + batch_top_token_ids, + batch_top_token_logprobs, ) # For each member of the batch @@ -641,6 +676,9 @@ class 
Seq2SeqLM(Model): next_token_chooser, stopping_criteria, all_decoder_input_ids, + top_n_tokens, + top_token_ids, + top_token_logprobs, ) in enumerate(iterator): # Select next token next_token_id, logprobs = next_token_chooser( @@ -698,6 +736,24 @@ class Seq2SeqLM(Model): else: prefill_tokens = None + if top_n_tokens > 0: + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids for token_id in top_token_ids + ] + top_tokens = TopTokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + else: + top_tokens = None + generation = Generation( request.id, prefill_tokens, @@ -706,6 +762,7 @@ class Seq2SeqLM(Model): next_token_text, next_token_id_squeezed.item() in self.all_special_ids, generated_text, + top_tokens, ) generations.append(generation) diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py index 28ca8147..0e27680d 100644 --- a/server/text_generation_server/models/types.py +++ b/server/text_generation_server/models/types.py @@ -1,3 +1,4 @@ +from functools import total_ordering import torch from abc import ABC, abstractmethod @@ -71,6 +72,25 @@ class PrefillTokens: return len(self.token_ids) +@dataclass +class TopTokens: + token_ids: List[int] + logprobs: List[float] + texts: List[str] + is_special: List[bool] + + def to_pb(self) -> generate_pb2.TopTokens: + return generate_pb2.TopTokens( + ids=self.token_ids, + logprobs=self.logprobs, + texts=self.texts, + is_special=self.is_special, + ) + + def __len__(self): + return len(self.token_ids) + + @dataclass class Generation: request_id: int @@ -80,6 +100,8 @@ class Generation: token_text: str token_is_special: bool generated_text: Optional[GeneratedText] + # Optional for now, since it's not yet supported for every model. 
+    top_tokens: Optional[TopTokens]
 
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -94,4 +116,5 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
+            top_tokens=self.top_tokens.to_pb() if self.top_tokens is not None else None,
         )
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index b83af591..69177d56 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,24 +1,20 @@
 import re
+from typing import Callable, List, Optional, Tuple
+
 import torch
-
-from transformers import (
-    RepetitionPenaltyLogitsProcessor,
-    PreTrainedTokenizerBase,
-)
-from typing import List, Tuple, Optional
-
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason
-from text_generation_server.utils.watermark import WatermarkLogitsProcessor
 from text_generation_server.utils.logits_process import (
-    static_warper,
+    HeterogeneousProcessorWrapper,
     HeterogeneousRepetitionPenaltyLogitsProcessor,
     HeterogeneousTemperatureLogitsWarper,
     HeterogeneousTopKLogitsWarper,
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
-    HeterogeneousProcessorWrapper,
+    static_warper,
 )
+from text_generation_server.utils.watermark import WatermarkLogitsProcessor
+from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor
 
 
 class NextTokenChooser:
@@ -229,11 +225,10 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)
 
         next_ids = self.choice(scores)
-        next_logprobs = torch.gather(
-            torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)
-        ).view(-1)
+        logprobs = torch.log_softmax(scores, -1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
-        return next_ids, next_logprobs
+        return next_ids, next_logprobs, logprobs
 
     def filter(self, indices):
         if self.watermark_processor is not None:
@@ -339,3 +334,50 @@ class HeterogeneousSampling:
         self.greedy_indices = new_greedy_indices
         self.sampling_mapping = new_sampling_mapping
         return self
+
+
+def batch_top_tokens(
+    top_n_tokens: list[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor
+) -> Tuple[List[List[int]], List[List[float]]]:
+    """Find the top n most likely tokens for a batch of generations.
+
+    When multiple tokens have equal probabilities and they don't all fit, the
+    remaining tokens are also returned.
+    """
+    max_top_n = max(top_n_tokens)
+    # Early exit when top_n_tokens is not used
+    if max_top_n == 0:
+        return [[]] * len(top_n_tokens), [[]] * len(top_n_tokens)
+
+    # Ensure top_n doesn't exceed vocab size
+    top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens]
+
+    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
+    # Sorted topk is faster than torch.sort() since we only need a small subset
+    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=1, sorted=True).values
+    nth_highest = torch.gather(
+        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)
+    )
+    nth_highest[nth_highest == -float("inf")] = torch.finfo(logprobs.dtype).min
+
+    # Find the new "fuzzy" top n values
+    top_n_indices = (logprobs >= nth_highest).nonzero()
+    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
+
+    # Take a new topk for these new max n values
+    top_k = torch.topk(logprobs, k=top_n_ishes.max(), dim=1, sorted=True)
+
+    top_n_ishes = top_n_ishes.tolist()
+    top_indices = top_k.indices.tolist()
+    top_values = top_k.values.tolist()
+
+    return (
+        [
+            idxs[:n] if req_n > 0 else []
+            for idxs, n, req_n in zip(top_indices, top_n_ishes, top_n_tokens)
+        ],
+        [
+            vals[:n] if req_n > 0 else []
+            for vals, n, req_n in zip(top_values, top_n_ishes, top_n_tokens)
+        ],
+    )
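For reference, `batch_top_tokens` computes the n-th highest value per request with one batched `torch.topk`, then keeps every token whose score ties that cutoff, so a request can get back more than `top_n_tokens` entries when there is a tie at the boundary (the "fuzzy" top n mentioned in the comments). Below is a small standalone sketch reusing the exact tensors from `test_batch_top_tokens` above; it assumes only that the server package from this diff is importable.

import torch
from text_generation_server.utils.tokens import batch_top_tokens

top_n_tokens = [0, 2, 3, 4, 5]
top_n_tokens_tensor = torch.tensor(top_n_tokens)
# One identical row of log-probabilities per request; indices 1 and 4 tie at -3.
logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)

ids, logps = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, logprobs)

print(ids[0])    # []            -> top_n_tokens == 0 disables the feature for that request
print(ids[1])    # [0, 3]        -> exactly the 2 requested tokens, no tie at the cutoff
print(ids[2])    # [0, 3, 1, 4]  -> 3 requested, but the tie at -3 pulls in a 4th token
print(logps[4])  # [-1.0, -2.0, -3.0, -3.0, -4.0]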