From 55bd4fed7da83a566dca08b0bb29dbc5929a90eb Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 9 Mar 2023 15:30:54 +0100 Subject: [PATCH] feat(router): add best_of parameter (#117) --- docs/openapi.json | 97 +++++++++++++-- launcher/src/main.rs | 5 + router/src/infer.rs | 38 ++++++ router/src/lib.rs | 34 +++++- router/src/main.rs | 4 + router/src/server.rs | 251 +++++++++++++++++++++++---------------- router/src/validation.rs | 51 +++++++- 7 files changed, 365 insertions(+), 115 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 7ece1722..881b892c 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -210,13 +210,63 @@ }, "components": { "schemas": { + "BestOfSequence": { + "type": "object", + "required": [ + "generated_text", + "finish_reason", + "generated_tokens", + "prefill", + "tokens" + ], + "properties": { + "finish_reason": { + "$ref": "#/components/schemas/FinishReason" + }, + "generated_text": { + "type": "string", + "example": "test" + }, + "generated_tokens": { + "type": "integer", + "format": "int32", + "example": 1 + }, + "prefill": { + "type": "array", + "items": { + "$ref": "#/components/schemas/PrefillToken" + } + }, + "seed": { + "type": "integer", + "format": "int64", + "example": 42, + "nullable": true + }, + "tokens": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Token" + } + } + } + }, "Details": { "type": "object", "required": [ "finish_reason", - "generated_tokens" + "generated_tokens", + "prefill", + "tokens" ], "properties": { + "best_of_sequences": { + "type": "array", + "items": { + "$ref": "#/components/schemas/BestOfSequence" + } + }, "finish_reason": { "$ref": "#/components/schemas/FinishReason" }, @@ -234,7 +284,8 @@ "seed": { "type": "integer", "format": "int64", - "example": 42 + "example": 42, + "nullable": true }, "tokens": { "type": "array", @@ -247,11 +298,15 @@ "ErrorResponse": { "type": "object", "required": [ - "error" + "error", + "error_type" ], "properties": { "error": { "type": "string" + }, + "error_type": { + "type": "string" } } }, @@ -266,6 +321,13 @@ "GenerateParameters": { "type": "object", "properties": { + "best_of": { + "type": "integer", + "default": "null", + "example": 1, + "nullable": true, + "exclusiveMinimum": 0.0 + }, "details": { "type": "boolean", "default": "true" @@ -292,12 +354,17 @@ }, "return_full_text": { "type": "boolean", - "default": "None", - "example": false + "default": "null", + "example": false, + "nullable": true }, "seed": { "type": "integer", - "format": "int64" + "format": "int64", + "default": "null", + "example": "null", + "nullable": true, + "exclusiveMinimum": 0.0 }, "stop": { "type": "array", @@ -334,6 +401,21 @@ "maximum": 1.0, "exclusiveMinimum": 0.0 }, + "truncate": { + "type": "integer", + "default": "null", + "example": "null", + "nullable": true + }, + "typical_p": { + "type": "number", + "format": "float", + "default": "null", + "example": 0.95, + "nullable": true, + "maximum": 1.0, + "exclusiveMinimum": 0.0 + }, "watermark": { "type": "boolean", "default": "false", @@ -414,7 +496,8 @@ "seed": { "type": "integer", "format": "int64", - "example": 42 + "example": 42, + "nullable": true } } }, diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 1865cf90..80466fe6 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -31,6 +31,8 @@ struct Args { quantize: bool, #[clap(default_value = "128", long, env)] max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, #[clap(default_value = "4", long, env)] max_stop_sequences: usize, #[clap(default_value = "1000", long, env)] @@ -86,6 +88,7 @@ fn main() -> ExitCode { num_shard, quantize, max_concurrent_requests, + max_best_of, max_stop_sequences, max_input_length, max_total_tokens, @@ -363,6 +366,8 @@ fn main() -> ExitCode { "text-generation-router".to_string(), "--max-concurrent-requests".to_string(), max_concurrent_requests.to_string(), + "--max-best-of".to_string(), + max_best_of.to_string(), "--max-stop-sequences".to_string(), max_stop_sequences.to_string(), "--max-input-length".to_string(), diff --git a/router/src/infer.rs b/router/src/infer.rs index d0964f97..5955faec 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -2,6 +2,7 @@ use crate::validation::{Validation, ValidationError}; use crate::{Entry, Queue, Token}; use crate::{GenerateRequest, PrefillToken}; +use futures::future::try_join_all; use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_client::{ @@ -177,6 +178,43 @@ impl Infer { Err(err) } } + /// Add best_of new requests to the queue and return a InferResponse of the sequence with + /// the highest log probability per token + #[instrument(skip(self))] + pub(crate) async fn generate_best_of( + &self, + request: GenerateRequest, + best_of: usize, + ) -> Result<(InferResponse, Vec), InferError> { + // validate best_of parameter separately + let best_of = self.validation.validate_best_of(best_of)?; + + // create multiple generate requests + let mut infer_responses: Vec = + try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?; + + // get the sequence with the highest log probability per token + let mut max_index = 0; + let mut max_logprob: f32 = f32::MIN; + + for (i, response) in infer_responses.iter().enumerate() { + // mean logprobs of the generated tokens + let sequence_logprob = response + .tokens + .iter() + .map(|token| token.logprob) + .sum::() + / response.tokens.len() as f32; + + // set best sequence + if sequence_logprob > max_logprob { + max_index = i; + max_logprob = sequence_logprob; + } + } + let best_response = infer_responses.remove(max_index); + Ok((best_response, infer_responses)) + } } /// Batching logic diff --git a/router/src/lib.rs b/router/src/lib.rs index 9fcc5085..91b4417c 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -12,6 +12,9 @@ use validation::Validation; #[derive(Clone, Debug, Deserialize, ToSchema)] pub(crate) struct GenerateParameters { + #[serde(default)] + #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)] + pub best_of: Option, #[serde(default)] #[schema( exclusive_minimum = 0.0, @@ -56,13 +59,13 @@ pub(crate) struct GenerateParameters { #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] pub max_new_tokens: u32, #[serde(default)] - #[schema(default = "null", example = false)] + #[schema(nullable = true, default = "null", example = false)] pub return_full_text: Option, #[serde(default)] #[schema(inline, max_items = 4, example = json ! (["photographer"]))] pub stop: Vec, #[serde(default)] - #[schema(default = "null", example = "null")] + #[schema(nullable = true, default = "null", example = "null")] pub truncate: Option, #[serde(default)] #[schema(default = "false", example = true)] @@ -71,6 +74,12 @@ pub(crate) struct GenerateParameters { #[schema(default = "true")] pub details: bool, #[serde(default)] + #[schema( + exclusive_minimum = 0, + nullable = true, + default = "null", + example = "null" + )] pub seed: Option, } @@ -80,6 +89,7 @@ fn default_max_new_tokens() -> u32 { fn default_parameters() -> GenerateParameters { GenerateParameters { + best_of: None, temperature: None, repetition_penalty: None, top_k: None, @@ -158,16 +168,32 @@ pub(crate) enum FinishReason { StopSequence, } +#[derive(Serialize, ToSchema)] +pub(crate) struct BestOfSequence { + #[schema(example = "test")] + pub generated_text: String, + #[schema(example = "length")] + pub finish_reason: FinishReason, + #[schema(example = 1)] + pub generated_tokens: u32, + #[schema(nullable = true, example = 42)] + pub seed: Option, + pub prefill: Vec, + pub tokens: Vec, +} + #[derive(Serialize, ToSchema)] pub(crate) struct Details { #[schema(example = "length")] pub finish_reason: FinishReason, #[schema(example = 1)] pub generated_tokens: u32, - #[schema(example = 42)] + #[schema(nullable = true, example = 42)] pub seed: Option, pub prefill: Vec, pub tokens: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub best_of_sequences: Option>, } #[derive(Serialize, ToSchema)] @@ -184,7 +210,7 @@ pub(crate) struct StreamDetails { pub finish_reason: FinishReason, #[schema(example = 1)] pub generated_tokens: u32, - #[schema(example = 42)] + #[schema(nullable = true, example = 42)] pub seed: Option, } diff --git a/router/src/main.rs b/router/src/main.rs index a51d3168..2ccf66b3 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -23,6 +23,8 @@ use tracing_subscriber::{EnvFilter, Layer}; struct Args { #[clap(default_value = "128", long, env)] max_concurrent_requests: usize, + #[clap(default_value = "2", long, env)] + max_best_of: usize, #[clap(default_value = "4", long, env)] max_stop_sequences: usize, #[clap(default_value = "1000", long, env)] @@ -55,6 +57,7 @@ fn main() -> Result<(), std::io::Error> { // Pattern match configuration let Args { max_concurrent_requests, + max_best_of, max_stop_sequences, max_input_length, max_total_tokens, @@ -145,6 +148,7 @@ fn main() -> Result<(), std::io::Error> { server::run( compat_return_full_text, max_concurrent_requests, + max_best_of, max_stop_sequences, max_input_length, max_total_tokens, diff --git a/router/src/server.rs b/router/src/server.rs index ef10b7b1..3b63ec8a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1,9 +1,10 @@ /// HTTP Server logic -use crate::infer::{InferError, InferStreamResponse}; +use crate::infer::{InferError, InferResponse, InferStreamResponse}; +use crate::validation::ValidationError; use crate::{ - CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters, - GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token, - Validation, + BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason, + GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, + StreamResponse, Token, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -64,6 +65,7 @@ async fn health(infer: Extension) -> Result<(), (StatusCode, Json 1 => { + let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?; + (response, Some(best_of_responses)) + } + _ => (infer.generate(req.0).await?, None), + }; // Token details let details = match details { - true => Some(Details { - finish_reason: FinishReason::from(response.generated_text.finish_reason), - generated_tokens: response.generated_text.generated_tokens, - prefill: response.prefill, - tokens: response.tokens, - seed: response.generated_text.seed, - }), + true => { + // convert best_of_responses + let best_of_sequences = best_of_responses.map(|responses: Vec| { + responses + .into_iter() + .map(|response: InferResponse| { + // Add prompt if return_full_text + let mut output_text = response.generated_text.text; + if let Some(prompt) = &add_prompt { + output_text = prompt.clone() + &output_text; + } + + BestOfSequence { + generated_text: output_text, + finish_reason: FinishReason::from( + response.generated_text.finish_reason, + ), + generated_tokens: response.generated_text.generated_tokens, + prefill: response.prefill, + tokens: response.tokens, + seed: response.generated_text.seed, + } + }) + .collect() + }); + + Some(Details { + finish_reason: FinishReason::from(response.generated_text.finish_reason), + generated_tokens: response.generated_text.generated_tokens, + prefill: response.prefill, + tokens: response.tokens, + seed: response.generated_text.seed, + best_of_sequences, + }) + } false => None, }; @@ -279,107 +315,115 @@ async fn generate_stream( } let details = req.0.parameters.details; - match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { - Ok(mut response_stream) => { - // Server-Sent Event stream - while let Some(response) = response_stream.next().await { - match response { - Ok(response) => { - match response { - // Prefill is ignored - InferStreamResponse::Prefill(_) => {} - // Yield event for every new token - InferStreamResponse::Token(token) => { - // StreamResponse - let stream_token = StreamResponse { - token, - generated_text: None, - details: None, - }; + let best_of = req.0.parameters.best_of.unwrap_or(1); + if best_of == 1 { + match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { + Ok(mut response_stream) => { + // Server-Sent Event stream + while let Some(response) = response_stream.next().await { + match response { + Ok(response) => { + match response { + // Prefill is ignored + InferStreamResponse::Prefill(_) => {} + // Yield event for every new token + InferStreamResponse::Token(token) => { + // StreamResponse + let stream_token = StreamResponse { + token, + generated_text: None, + details: None, + }; - yield Ok(Event::default().json_data(stream_token).unwrap()) - } - // Yield event for last token and compute timings - InferStreamResponse::End { - token, - generated_text, - start, - queued, - } => { - // Token details - let details = match details { - true => Some(StreamDetails { - finish_reason: FinishReason::from(generated_text.finish_reason), - generated_tokens: generated_text.generated_tokens, - seed: generated_text.seed, - }), - false => None, - }; - - // Timings - let total_time = start_time.elapsed(); - let validation_time = queued - start_time; - let queue_time = start - queued; - let inference_time = Instant::now() - start; - let time_per_token = inference_time / generated_text.generated_tokens; - - // Tracing metadata - span.record("total_time", format!("{total_time:?}")); - span.record("validation_time", format!("{validation_time:?}")); - span.record("queue_time", format!("{queue_time:?}")); - span.record("inference_time", format!("{inference_time:?}")); - span.record("time_per_token", format!("{time_per_token:?}")); - span.record("seed", format!("{:?}", generated_text.seed)); - tracing::info!(parent: &span, "Output: {}", generated_text.text); - - // Metrics - metrics::increment_counter!("tgi_request_success"); - metrics::histogram!("tgi_request_duration", total_time); - metrics::histogram!("tgi_request_validation_duration", validation_time); - metrics::histogram!("tgi_request_queue_duration", queue_time); - metrics::histogram!("tgi_request_inference_duration", inference_time); - metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token); - metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64); - - // StreamResponse - end_reached = true; - - let mut output_text = generated_text.text; - if let Some(prompt) = add_prompt { - output_text = prompt + &output_text; + yield Ok(Event::default().json_data(stream_token).unwrap()) } - - let stream_token = StreamResponse { + // Yield event for last token and compute timings + InferStreamResponse::End { token, - generated_text: Some(output_text), - details - }; + generated_text, + start, + queued, + } => { + // Token details + let details = match details { + true => Some(StreamDetails { + finish_reason: FinishReason::from(generated_text.finish_reason), + generated_tokens: generated_text.generated_tokens, + seed: generated_text.seed, + }), + false => None, + }; - yield Ok(Event::default().json_data(stream_token).unwrap()); - break; + // Timings + let total_time = start_time.elapsed(); + let validation_time = queued - start_time; + let queue_time = start - queued; + let inference_time = Instant::now() - start; + let time_per_token = inference_time / generated_text.generated_tokens; + + // Tracing metadata + span.record("total_time", format!("{total_time:?}")); + span.record("validation_time", format!("{validation_time:?}")); + span.record("queue_time", format!("{queue_time:?}")); + span.record("inference_time", format!("{inference_time:?}")); + span.record("time_per_token", format!("{time_per_token:?}")); + span.record("seed", format!("{:?}", generated_text.seed)); + tracing::info!(parent: &span, "Output: {}", generated_text.text); + + // Metrics + metrics::increment_counter!("tgi_request_success"); + metrics::histogram!("tgi_request_duration", total_time); + metrics::histogram!("tgi_request_validation_duration", validation_time); + metrics::histogram!("tgi_request_queue_duration", queue_time); + metrics::histogram!("tgi_request_inference_duration", inference_time); + metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token); + metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64); + + // StreamResponse + end_reached = true; + + let mut output_text = generated_text.text; + if let Some(prompt) = add_prompt { + output_text = prompt + &output_text; + } + + let stream_token = StreamResponse { + token, + generated_text: Some(output_text), + details + }; + + yield Ok(Event::default().json_data(stream_token).unwrap()); + break; + } } } - } - // yield error - Err(err) => { - error = true; - yield Ok(Event::from(err)); - break; + // yield error + Err(err) => { + error = true; + yield Ok(Event::from(err)); + break; + } } } + }, + // yield error + Err(err) => { + error = true; + yield Ok(Event::from(err)); } - }, - // yield error - Err(err) => { - error = true; + } + // Check if generation reached the end + // Skip if we already sent an error + if !end_reached && !error { + let err = InferError::IncompleteGeneration; + metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + tracing::error!("{err}"); yield Ok(Event::from(err)); } - } - // Check if generation reached the end - // Skip if we already sent an error - if !end_reached && !error { - let err = InferError::IncompleteGeneration; - metrics::increment_counter!("tgi_request_failure", "err" => "incomplete"); + } else { + let err = InferError::from(ValidationError::BestOfStream); + metrics::increment_counter!("tgi_request_failure", "err" => "validation"); tracing::error!("{err}"); yield Ok(Event::from(err)); } @@ -404,6 +448,7 @@ async fn metrics(prom_handle: Extension) -> String { pub async fn run( compat_return_full_text: bool, max_concurrent_requests: usize, + max_best_of: usize, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -430,6 +475,7 @@ pub async fn run( PrefillToken, Token, GenerateResponse, + BestOfSequence, Details, FinishReason, StreamResponse, @@ -454,6 +500,7 @@ pub async fn run( let validation = Validation::new( validation_workers, tokenizer, + max_best_of, max_stop_sequences, max_input_length, max_total_tokens, diff --git a/router/src/validation.rs b/router/src/validation.rs index 42af0169..cb8dd0a2 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -1,4 +1,4 @@ -use crate::validation::ValidationError::EmptyInput; +use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; /// Payload validation logic use crate::{GenerateParameters, GenerateRequest}; use rand::rngs::ThreadRng; @@ -13,6 +13,9 @@ use tracing::{instrument, Span}; /// Validation #[derive(Debug, Clone)] pub struct Validation { + /// maximum value for the best_of parameter + #[allow(dead_code)] + max_best_of: usize, /// Channel to communicate with the background validation task sender: mpsc::UnboundedSender, } @@ -21,6 +24,7 @@ impl Validation { pub(crate) fn new( workers: usize, tokenizer: Tokenizer, + max_best_of: usize, max_stop_sequences: usize, max_input_length: usize, max_total_tokens: usize, @@ -39,6 +43,7 @@ impl Validation { )); Self { + max_best_of, sender: validation_sender, } } @@ -60,6 +65,20 @@ impl Validation { // Unwrap is safe here receiver.await.unwrap() } + + /// Validate the best_of parameter + #[instrument(skip_all)] + pub(crate) fn validate_best_of(&self, best_of: usize) -> Result { + if self.max_best_of == 1 && best_of != 1 { + return Err(ValidationError::BestOfDisabled); + } + + if best_of > self.max_best_of { + return Err(ValidationError::BestOf(self.max_best_of, best_of)); + } + + Ok(best_of) + } } /// Validation task @@ -150,6 +169,7 @@ fn validate( rng: &mut ThreadRng, ) -> Result { let GenerateParameters { + best_of, temperature, repetition_penalty, top_k, @@ -164,6 +184,18 @@ fn validate( .. } = request.parameters; + // sampling must be true when best_of > 1 + let best_of = best_of.unwrap_or(1); + let sampling = do_sample + || temperature.is_some() + || top_k.is_some() + || top_p.is_some() + || typical_p.is_some(); + + if best_of > 1 && !sampling { + return Err(BestOfSampling); + } + let temperature = temperature.unwrap_or(1.0); if temperature <= 0.0 { return Err(ValidationError::Temperature); @@ -217,7 +249,12 @@ fn validate( // If seed is None, assign a random one let seed = match seed { None => rng.gen(), - Some(seed) => seed, + Some(seed) => { + if best_of > 1 { + return Err(BestOfSeed); + } + seed + } }; // Check if inputs is empty @@ -307,6 +344,16 @@ pub(crate) struct ValidGenerateRequest { #[derive(Error, Debug)] pub enum ValidationError { + #[error("`best_of` must be > 0 and <= {0}. Given: {1}")] + BestOf(usize, usize), + #[error("`best_of` != 1 is not allowed for this endpoint")] + BestOfDisabled, + #[error("you must use sampling when `best_of` is > 1")] + BestOfSampling, + #[error("`seed` must not be set when `best_of` > 1")] + BestOfSeed, + #[error("`best_of` != 1 is not supported when streaming tokens")] + BestOfStream, #[error("`temperature` must be strictly positive")] Temperature, #[error("`repetition_penalty` must be strictly positive")]