feat(router): add best_of parameter (#117)

parent e8bfe199ba
commit 55bd4fed7d
```diff
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -210,13 +210,63 @@
   },
   "components": {
     "schemas": {
+      "BestOfSequence": {
+        "type": "object",
+        "required": [
+          "generated_text",
+          "finish_reason",
+          "generated_tokens",
+          "prefill",
+          "tokens"
+        ],
+        "properties": {
+          "finish_reason": {
+            "$ref": "#/components/schemas/FinishReason"
+          },
+          "generated_text": {
+            "type": "string",
+            "example": "test"
+          },
+          "generated_tokens": {
+            "type": "integer",
+            "format": "int32",
+            "example": 1
+          },
+          "prefill": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/PrefillToken"
+            }
+          },
+          "seed": {
+            "type": "integer",
+            "format": "int64",
+            "example": 42,
+            "nullable": true
+          },
+          "tokens": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/Token"
+            }
+          }
+        }
+      },
       "Details": {
         "type": "object",
         "required": [
           "finish_reason",
-          "generated_tokens"
+          "generated_tokens",
+          "prefill",
+          "tokens"
         ],
         "properties": {
+          "best_of_sequences": {
+            "type": "array",
+            "items": {
+              "$ref": "#/components/schemas/BestOfSequence"
+            }
+          },
           "finish_reason": {
             "$ref": "#/components/schemas/FinishReason"
           },
@@ -234,7 +284,8 @@
           "seed": {
             "type": "integer",
             "format": "int64",
-            "example": 42
+            "example": 42,
+            "nullable": true
           },
           "tokens": {
             "type": "array",
@@ -247,11 +298,15 @@
       "ErrorResponse": {
         "type": "object",
        "required": [
-          "error"
+          "error",
+          "error_type"
         ],
         "properties": {
           "error": {
             "type": "string"
+          },
+          "error_type": {
+            "type": "string"
           }
         }
       },
@@ -266,6 +321,13 @@
       "GenerateParameters": {
         "type": "object",
         "properties": {
+          "best_of": {
+            "type": "integer",
+            "default": "null",
+            "example": 1,
+            "nullable": true,
+            "exclusiveMinimum": 0.0
+          },
           "details": {
             "type": "boolean",
             "default": "true"
@@ -292,12 +354,17 @@
          },
           "return_full_text": {
             "type": "boolean",
-            "default": "None",
-            "example": false
+            "default": "null",
+            "example": false,
+            "nullable": true
           },
           "seed": {
             "type": "integer",
-            "format": "int64"
+            "format": "int64",
+            "default": "null",
+            "example": "null",
+            "nullable": true,
+            "exclusiveMinimum": 0.0
           },
           "stop": {
             "type": "array",
@@ -334,6 +401,21 @@
           "maximum": 1.0,
           "exclusiveMinimum": 0.0
         },
+        "truncate": {
+          "type": "integer",
+          "default": "null",
+          "example": "null",
+          "nullable": true
+        },
+        "typical_p": {
+          "type": "number",
+          "format": "float",
+          "default": "null",
+          "example": 0.95,
+          "nullable": true,
+          "maximum": 1.0,
+          "exclusiveMinimum": 0.0
+        },
         "watermark": {
           "type": "boolean",
           "default": "false",
@@ -414,7 +496,8 @@
         "seed": {
           "type": "integer",
           "format": "int64",
-          "example": 42
+          "example": 42,
+          "nullable": true
         }
       }
     },
```
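Taken together with the validation rules added further down, a request exercising the new parameter against the `/generate` route would look roughly like this. This is an illustrative payload, not part of the commit; note that `best_of > 1` requires sampling and is incompatible with a caller-fixed `seed`:

```json
{
  "inputs": "What is Deep Learning?",
  "parameters": {
    "best_of": 2,
    "do_sample": true,
    "max_new_tokens": 20,
    "details": true
  }
}
```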
```diff
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -31,6 +31,8 @@ struct Args {
     quantize: bool,
     #[clap(default_value = "128", long, env)]
     max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
     #[clap(default_value = "4", long, env)]
     max_stop_sequences: usize,
     #[clap(default_value = "1000", long, env)]
@@ -86,6 +88,7 @@ fn main() -> ExitCode {
         num_shard,
         quantize,
         max_concurrent_requests,
+        max_best_of,
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
@@ -363,6 +366,8 @@ fn main() -> ExitCode {
         "text-generation-router".to_string(),
         "--max-concurrent-requests".to_string(),
         max_concurrent_requests.to_string(),
+        "--max-best-of".to_string(),
+        max_best_of.to_string(),
         "--max-stop-sequences".to_string(),
         max_stop_sequences.to_string(),
         "--max-input-length".to_string(),
```
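On the deployment side the cap is an ordinary clap argument (`--max-best-of`, default 2) that the launcher forwards to the router process; since it is declared with `env`, clap should also accept it as the `MAX_BEST_OF` environment variable. An invocation sketch, with all other flags elided:

```
text-generation-launcher ... --max-best-of 4
MAX_BEST_OF=4 text-generation-launcher ...
```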
```diff
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -2,6 +2,7 @@
 use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
+use futures::future::try_join_all;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_client::{
@@ -177,6 +178,43 @@ impl Infer {
             Err(err)
         }
     }
+
+    /// Add best_of new requests to the queue and return an InferResponse of the sequence with
+    /// the highest log probability per token
+    #[instrument(skip(self))]
+    pub(crate) async fn generate_best_of(
+        &self,
+        request: GenerateRequest,
+        best_of: usize,
+    ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
+        // validate best_of parameter separately
+        let best_of = self.validation.validate_best_of(best_of)?;
+
+        // create multiple generate requests
+        let mut infer_responses: Vec<InferResponse> =
+            try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
+
+        // get the sequence with the highest log probability per token
+        let mut max_index = 0;
+        let mut max_logprob: f32 = f32::MIN;
+
+        for (i, response) in infer_responses.iter().enumerate() {
+            // mean logprobs of the generated tokens
+            let sequence_logprob = response
+                .tokens
+                .iter()
+                .map(|token| token.logprob)
+                .sum::<f32>()
+                / response.tokens.len() as f32;
+
+            // set best sequence
+            if sequence_logprob > max_logprob {
+                max_index = i;
+                max_logprob = sequence_logprob;
+            }
+        }
+        let best_response = infer_responses.remove(max_index);
+        Ok((best_response, infer_responses))
+    }
 }

 /// Batching logic
```
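Two details of this method are worth spelling out. The candidate requests are issued concurrently via `try_join_all`, so `best_of` trades extra compute for a latency close to that of a single generation. And the ranking criterion is the mean per-token log-probability rather than the raw sum, so longer sequences are not penalized simply for accumulating more negative terms. A self-contained sketch of the same selection logic, using stand-in types rather than the crate's own:

```rust
// Stand-ins for the router's Token / InferResponse types.
struct Token {
    logprob: f32,
}

struct Response {
    tokens: Vec<Token>,
}

/// Return the index of the response with the highest mean per-token logprob.
fn best_index(responses: &[Response]) -> usize {
    let mut max_index = 0;
    let mut max_logprob = f32::MIN;
    for (i, response) in responses.iter().enumerate() {
        // Mean instead of sum: a 50-token sequence is not unfairly
        // outscored by a 5-token one just because it has fewer terms.
        let mean = response.tokens.iter().map(|t| t.logprob).sum::<f32>()
            / response.tokens.len() as f32;
        if mean > max_logprob {
            max_index = i;
            max_logprob = mean;
        }
    }
    max_index
}

fn main() {
    let responses = vec![
        Response { tokens: vec![Token { logprob: -0.9 }, Token { logprob: -1.1 }] },
        Response { tokens: vec![Token { logprob: -0.2 }, Token { logprob: -0.4 }] },
    ];
    // The second sequence has the higher mean logprob (-0.3 vs -1.0).
    assert_eq!(best_index(&responses), 1);
    println!("best sequence: {}", best_index(&responses));
}
```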
```diff
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -12,6 +12,9 @@ use validation::Validation;

 #[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateParameters {
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
+    pub best_of: Option<usize>,
     #[serde(default)]
     #[schema(
         exclusive_minimum = 0.0,
@@ -56,13 +59,13 @@ pub(crate) struct GenerateParameters {
     #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
     pub max_new_tokens: u32,
     #[serde(default)]
-    #[schema(default = "null", example = false)]
+    #[schema(nullable = true, default = "null", example = false)]
     pub return_full_text: Option<bool>,
     #[serde(default)]
     #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
     pub stop: Vec<String>,
     #[serde(default)]
-    #[schema(default = "null", example = "null")]
+    #[schema(nullable = true, default = "null", example = "null")]
     pub truncate: Option<usize>,
     #[serde(default)]
     #[schema(default = "false", example = true)]
@@ -71,6 +74,12 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "true")]
     pub details: bool,
     #[serde(default)]
+    #[schema(
+        exclusive_minimum = 0,
+        nullable = true,
+        default = "null",
+        example = "null"
+    )]
     pub seed: Option<u64>,
 }
@@ -80,6 +89,7 @@ fn default_max_new_tokens() -> u32 {

 fn default_parameters() -> GenerateParameters {
     GenerateParameters {
+        best_of: None,
         temperature: None,
         repetition_penalty: None,
         top_k: None,
@@ -158,16 +168,32 @@ pub(crate) enum FinishReason {
     StopSequence,
 }

+#[derive(Serialize, ToSchema)]
+pub(crate) struct BestOfSequence {
+    #[schema(example = "test")]
+    pub generated_text: String,
+    #[schema(example = "length")]
+    pub finish_reason: FinishReason,
+    #[schema(example = 1)]
+    pub generated_tokens: u32,
+    #[schema(nullable = true, example = 42)]
+    pub seed: Option<u64>,
+    pub prefill: Vec<PrefillToken>,
+    pub tokens: Vec<Token>,
+}
+
 #[derive(Serialize, ToSchema)]
 pub(crate) struct Details {
     #[schema(example = "length")]
     pub finish_reason: FinishReason,
     #[schema(example = 1)]
     pub generated_tokens: u32,
-    #[schema(example = 42)]
+    #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
     pub prefill: Vec<PrefillToken>,
     pub tokens: Vec<Token>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub best_of_sequences: Option<Vec<BestOfSequence>>,
 }

 #[derive(Serialize, ToSchema)]
@@ -184,7 +210,7 @@ pub(crate) struct StreamDetails {
     pub finish_reason: FinishReason,
     #[schema(example = 1)]
     pub generated_tokens: u32,
-    #[schema(example = 42)]
+    #[schema(nullable = true, example = 42)]
     pub seed: Option<u64>,
 }
```
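The `skip_serializing_if = "Option::is_none"` attribute on `best_of_sequences` keeps the field out of the JSON entirely for ordinary single-sequence requests, instead of emitting a `null`. A minimal sketch of that serde behavior, with a stand-in element type:

```rust
use serde::Serialize; // assumes serde = { version = "1", features = ["derive"] }, serde_json = "1"

#[derive(Serialize)]
struct Details {
    generated_tokens: u32,
    // Omit the key entirely when there are no runner-up sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    best_of_sequences: Option<Vec<String>>, // stand-in for Vec<BestOfSequence>
}

fn main() {
    let without = Details { generated_tokens: 20, best_of_sequences: None };
    let with = Details {
        generated_tokens: 20,
        best_of_sequences: Some(vec!["runner-up".to_string()]),
    };
    println!("{}", serde_json::to_string(&without).unwrap());
    // {"generated_tokens":20}
    println!("{}", serde_json::to_string(&with).unwrap());
    // {"generated_tokens":20,"best_of_sequences":["runner-up"]}
}
```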
```diff
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -23,6 +23,8 @@ use tracing_subscriber::{EnvFilter, Layer};
 struct Args {
     #[clap(default_value = "128", long, env)]
     max_concurrent_requests: usize,
+    #[clap(default_value = "2", long, env)]
+    max_best_of: usize,
     #[clap(default_value = "4", long, env)]
     max_stop_sequences: usize,
     #[clap(default_value = "1000", long, env)]
@@ -55,6 +57,7 @@ fn main() -> Result<(), std::io::Error> {
     // Pattern match configuration
     let Args {
         max_concurrent_requests,
+        max_best_of,
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
@@ -145,6 +148,7 @@ fn main() -> Result<(), std::io::Error> {
     server::run(
         compat_return_full_text,
         max_concurrent_requests,
+        max_best_of,
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
```
```diff
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1,9 +1,10 @@
 /// HTTP Server logic
-use crate::infer::{InferError, InferStreamResponse};
+use crate::infer::{InferError, InferResponse, InferStreamResponse};
+use crate::validation::ValidationError;
 use crate::{
-    CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
-    GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
-    Validation,
+    BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
+    GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
+    StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -64,6 +65,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         .generate(GenerateRequest {
             inputs: "liveness".to_string(),
             parameters: GenerateParameters {
+                best_of: None,
                 temperature: None,
                 repetition_penalty: None,
                 top_k: None,
```
```diff
@@ -128,17 +130,51 @@ async fn generate(
     let details = req.0.parameters.details;

     // Inference
-    let response = infer.generate(req.0).await?;
+    let (response, best_of_responses) = match req.0.parameters.best_of {
+        Some(best_of) if best_of > 1 => {
+            let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+            (response, Some(best_of_responses))
+        }
+        _ => (infer.generate(req.0).await?, None),
+    };

     // Token details
     let details = match details {
-        true => Some(Details {
-            finish_reason: FinishReason::from(response.generated_text.finish_reason),
-            generated_tokens: response.generated_text.generated_tokens,
-            prefill: response.prefill,
-            tokens: response.tokens,
-            seed: response.generated_text.seed,
-        }),
+        true => {
+            // convert best_of_responses
+            let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
+                responses
+                    .into_iter()
+                    .map(|response: InferResponse| {
+                        // Add prompt if return_full_text
+                        let mut output_text = response.generated_text.text;
+                        if let Some(prompt) = &add_prompt {
+                            output_text = prompt.clone() + &output_text;
+                        }
+
+                        BestOfSequence {
+                            generated_text: output_text,
+                            finish_reason: FinishReason::from(
+                                response.generated_text.finish_reason,
+                            ),
+                            generated_tokens: response.generated_text.generated_tokens,
+                            prefill: response.prefill,
+                            tokens: response.tokens,
+                            seed: response.generated_text.seed,
+                        }
+                    })
+                    .collect()
+            });
+
+            Some(Details {
+                finish_reason: FinishReason::from(response.generated_text.finish_reason),
+                generated_tokens: response.generated_text.generated_tokens,
+                prefill: response.prefill,
+                tokens: response.tokens,
+                seed: response.generated_text.seed,
+                best_of_sequences,
+            })
+        }
         false => None,
     };
```
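The dispatch above keeps the default path untouched: only `best_of: Some(n)` with `n > 1` routes through `generate_best_of`, so existing clients see no behavioral change, and the runner-ups get the same `return_full_text` prompt-prepending treatment as the winner. With `details: true`, a response might then look roughly like this (values illustrative, token arrays trimmed):

```json
{
  "generated_text": "Deep Learning is a subset of machine learning...",
  "details": {
    "finish_reason": "length",
    "generated_tokens": 20,
    "seed": 42,
    "prefill": [],
    "tokens": [],
    "best_of_sequences": [
      {
        "generated_text": "Deep Learning refers to...",
        "finish_reason": "length",
        "generated_tokens": 20,
        "seed": 17,
        "prefill": [],
        "tokens": []
      }
    ]
  }
}
```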
```diff
@@ -279,107 +315,115 @@ async fn generate_stream(
             }
             let details = req.0.parameters.details;

+            let best_of = req.0.parameters.best_of.unwrap_or(1);
+            if best_of == 1 {
                 match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
                     Ok(mut response_stream) => {
                         // Server-Sent Event stream
                         while let Some(response) = response_stream.next().await {
                             match response {
                                 Ok(response) => {
                                     match response {
                                         // Prefill is ignored
                                         InferStreamResponse::Prefill(_) => {}
                                         // Yield event for every new token
                                         InferStreamResponse::Token(token) => {
                                             // StreamResponse
                                             let stream_token = StreamResponse {
                                                 token,
                                                 generated_text: None,
                                                 details: None,
                                             };

                                             yield Ok(Event::default().json_data(stream_token).unwrap())
                                         }
                                         // Yield event for last token and compute timings
                                         InferStreamResponse::End {
                                             token,
                                             generated_text,
                                             start,
                                             queued,
                                         } => {
                                             // Token details
                                             let details = match details {
                                                 true => Some(StreamDetails {
                                                     finish_reason: FinishReason::from(generated_text.finish_reason),
                                                     generated_tokens: generated_text.generated_tokens,
                                                     seed: generated_text.seed,
                                                 }),
                                                 false => None,
                                             };

                                             // Timings
                                             let total_time = start_time.elapsed();
                                             let validation_time = queued - start_time;
                                             let queue_time = start - queued;
                                             let inference_time = Instant::now() - start;
                                             let time_per_token = inference_time / generated_text.generated_tokens;

                                             // Tracing metadata
                                             span.record("total_time", format!("{total_time:?}"));
                                             span.record("validation_time", format!("{validation_time:?}"));
                                             span.record("queue_time", format!("{queue_time:?}"));
                                             span.record("inference_time", format!("{inference_time:?}"));
                                             span.record("time_per_token", format!("{time_per_token:?}"));
                                             span.record("seed", format!("{:?}", generated_text.seed));
                                             tracing::info!(parent: &span, "Output: {}", generated_text.text);

                                             // Metrics
                                             metrics::increment_counter!("tgi_request_success");
                                             metrics::histogram!("tgi_request_duration", total_time);
                                             metrics::histogram!("tgi_request_validation_duration", validation_time);
                                             metrics::histogram!("tgi_request_queue_duration", queue_time);
                                             metrics::histogram!("tgi_request_inference_duration", inference_time);
                                             metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
                                             metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);

                                             // StreamResponse
                                             end_reached = true;

                                             let mut output_text = generated_text.text;
                                             if let Some(prompt) = add_prompt {
                                                 output_text = prompt + &output_text;
                                             }

                                             let stream_token = StreamResponse {
                                                 token,
                                                 generated_text: Some(output_text),
                                                 details
                                             };

                                             yield Ok(Event::default().json_data(stream_token).unwrap());
                                             break;
                                         }
                                     }
                                 }
                                 // yield error
                                 Err(err) => {
                                     error = true;
                                     yield Ok(Event::from(err));
                                     break;
                                 }
                             }
                         }
                     },
                     // yield error
                     Err(err) => {
                         error = true;
                         yield Ok(Event::from(err));
                     }
                 }
                 // Check if generation reached the end
                 // Skip if we already sent an error
                 if !end_reached && !error {
                     let err = InferError::IncompleteGeneration;
                     metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
                     tracing::error!("{err}");
                     yield Ok(Event::from(err));
                 }
+            } else {
+                let err = InferError::from(ValidationError::BestOfStream);
+                metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+                tracing::error!("{err}");
+                yield Ok(Event::from(err));
+            }
```
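The streaming endpoint rejects `best_of > 1` up front: with tokens multiplexed over a single SSE connection there is no sensible way to interleave several candidate sequences. The error event presumably serializes through the extended `ErrorResponse` schema, carrying the new `error_type` field alongside the message; an illustrative body:

```json
{
  "error": "`best_of` != 1 is not supported when streaming tokens",
  "error_type": "validation"
}
```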
```diff
@@ -404,6 +448,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 pub async fn run(
     compat_return_full_text: bool,
     max_concurrent_requests: usize,
+    max_best_of: usize,
     max_stop_sequences: usize,
     max_input_length: usize,
     max_total_tokens: usize,
@@ -430,6 +475,7 @@ pub async fn run(
             PrefillToken,
             Token,
             GenerateResponse,
+            BestOfSequence,
             Details,
             FinishReason,
             StreamResponse,
@@ -454,6 +500,7 @@ pub async fn run(
     let validation = Validation::new(
         validation_workers,
         tokenizer,
+        max_best_of,
         max_stop_sequences,
         max_input_length,
         max_total_tokens,
```
```diff
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -1,4 +1,4 @@
-use crate::validation::ValidationError::EmptyInput;
+use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 /// Payload validation logic
 use crate::{GenerateParameters, GenerateRequest};
 use rand::rngs::ThreadRng;
@@ -13,6 +13,9 @@ use tracing::{instrument, Span};
 /// Validation
 #[derive(Debug, Clone)]
 pub struct Validation {
+    /// maximum value for the best_of parameter
+    #[allow(dead_code)]
+    max_best_of: usize,
     /// Channel to communicate with the background validation task
     sender: mpsc::UnboundedSender<ValidationRequest>,
 }
@@ -21,6 +24,7 @@ impl Validation {
     pub(crate) fn new(
         workers: usize,
         tokenizer: Tokenizer,
+        max_best_of: usize,
         max_stop_sequences: usize,
         max_input_length: usize,
         max_total_tokens: usize,
@@ -39,6 +43,7 @@ impl Validation {
         ));

         Self {
+            max_best_of,
             sender: validation_sender,
         }
     }
@@ -60,6 +65,20 @@ impl Validation {
         // Unwrap is safe here
         receiver.await.unwrap()
     }
+
+    /// Validate the best_of parameter
+    #[instrument(skip_all)]
+    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
+        if self.max_best_of == 1 && best_of != 1 {
+            return Err(ValidationError::BestOfDisabled);
+        }
+
+        if best_of > self.max_best_of {
+            return Err(ValidationError::BestOf(self.max_best_of, best_of));
+        }
+
+        Ok(best_of)
+    }
 }

 /// Validation task
@@ -150,6 +169,7 @@ fn validate(
     rng: &mut ThreadRng,
 ) -> Result<ValidGenerateRequest, ValidationError> {
     let GenerateParameters {
+        best_of,
         temperature,
         repetition_penalty,
         top_k,
@@ -164,6 +184,18 @@ fn validate(
         ..
     } = request.parameters;

+    // sampling must be true when best_of > 1
+    let best_of = best_of.unwrap_or(1);
+    let sampling = do_sample
+        || temperature.is_some()
+        || top_k.is_some()
+        || top_p.is_some()
+        || typical_p.is_some();
+
+    if best_of > 1 && !sampling {
+        return Err(BestOfSampling);
+    }
+
     let temperature = temperature.unwrap_or(1.0);
     if temperature <= 0.0 {
         return Err(ValidationError::Temperature);
@@ -217,7 +249,12 @@ fn validate(
     // If seed is None, assign a random one
     let seed = match seed {
         None => rng.gen(),
-        Some(seed) => seed,
+        Some(seed) => {
+            if best_of > 1 {
+                return Err(BestOfSeed);
+            }
+            seed
+        }
     };

     // Check if inputs is empty
@@ -307,6 +344,16 @@ pub(crate) struct ValidGenerateRequest {

 #[derive(Error, Debug)]
 pub enum ValidationError {
+    #[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
+    BestOf(usize, usize),
+    #[error("`best_of` != 1 is not allowed for this endpoint")]
+    BestOfDisabled,
+    #[error("you must use sampling when `best_of` is > 1")]
+    BestOfSampling,
+    #[error("`seed` must not be set when `best_of` > 1")]
+    BestOfSeed,
+    #[error("`best_of` != 1 is not supported when streaming tokens")]
+    BestOfStream,
     #[error("`temperature` must be strictly positive")]
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
```
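Two of the new rules encode why best-of only makes sense for stochastic decoding: without sampling every candidate would be token-for-token identical, and a caller-fixed `seed` is presumably cloned into each sub-request with the same effect, so `BestOfSampling` and `BestOfSeed` reject those combinations up front. To make the cap semantics concrete, here is a self-contained sketch of `validate_best_of` with stand-in types rather than the crate's own:

```rust
// Stand-ins for the router's ValidationError variants relevant here.
#[derive(Debug, PartialEq)]
enum ValidationError {
    BestOf(usize, usize),
    BestOfDisabled,
}

struct Validation {
    max_best_of: usize,
}

impl Validation {
    // Mirrors the checks above: any best_of != 1 is rejected outright when
    // the deployment caps it at 1, and otherwise must not exceed the cap.
    fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
        if self.max_best_of == 1 && best_of != 1 {
            return Err(ValidationError::BestOfDisabled);
        }
        if best_of > self.max_best_of {
            return Err(ValidationError::BestOf(self.max_best_of, best_of));
        }
        Ok(best_of)
    }
}

fn main() {
    let v = Validation { max_best_of: 2 };
    assert_eq!(v.validate_best_of(2), Ok(2));
    assert_eq!(v.validate_best_of(3), Err(ValidationError::BestOf(2, 3)));

    let disabled = Validation { max_best_of: 1 };
    assert_eq!(disabled.validate_best_of(2), Err(ValidationError::BestOfDisabled));
}
```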