feat(router): add best_of parameter (#117)

This commit is contained in:
OlivierDehaene 2023-03-09 15:30:54 +01:00 committed by GitHub
parent e8bfe199ba
commit 55bd4fed7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 365 additions and 115 deletions

View File

@ -210,13 +210,63 @@
}, },
"components": { "components": {
"schemas": { "schemas": {
"BestOfSequence": {
"type": "object",
"required": [
"generated_text",
"finish_reason",
"generated_tokens",
"prefill",
"tokens"
],
"properties": {
"finish_reason": {
"$ref": "#/components/schemas/FinishReason"
},
"generated_text": {
"type": "string",
"example": "test"
},
"generated_tokens": {
"type": "integer",
"format": "int32",
"example": 1
},
"prefill": {
"type": "array",
"items": {
"$ref": "#/components/schemas/PrefillToken"
}
},
"seed": {
"type": "integer",
"format": "int64",
"example": 42,
"nullable": true
},
"tokens": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Token"
}
}
}
},
"Details": { "Details": {
"type": "object", "type": "object",
"required": [ "required": [
"finish_reason", "finish_reason",
"generated_tokens" "generated_tokens",
"prefill",
"tokens"
], ],
"properties": { "properties": {
"best_of_sequences": {
"type": "array",
"items": {
"$ref": "#/components/schemas/BestOfSequence"
}
},
"finish_reason": { "finish_reason": {
"$ref": "#/components/schemas/FinishReason" "$ref": "#/components/schemas/FinishReason"
}, },
@ -234,7 +284,8 @@
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
"example": 42 "example": 42,
"nullable": true
}, },
"tokens": { "tokens": {
"type": "array", "type": "array",
@ -247,11 +298,15 @@
"ErrorResponse": { "ErrorResponse": {
"type": "object", "type": "object",
"required": [ "required": [
"error" "error",
"error_type"
], ],
"properties": { "properties": {
"error": { "error": {
"type": "string" "type": "string"
},
"error_type": {
"type": "string"
} }
} }
}, },
@ -266,6 +321,13 @@
"GenerateParameters": { "GenerateParameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"best_of": {
"type": "integer",
"default": "null",
"example": 1,
"nullable": true,
"exclusiveMinimum": 0.0
},
"details": { "details": {
"type": "boolean", "type": "boolean",
"default": "true" "default": "true"
@ -292,12 +354,17 @@
}, },
"return_full_text": { "return_full_text": {
"type": "boolean", "type": "boolean",
"default": "None", "default": "null",
"example": false "example": false,
"nullable": true
}, },
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64" "format": "int64",
"default": "null",
"example": "null",
"nullable": true,
"exclusiveMinimum": 0.0
}, },
"stop": { "stop": {
"type": "array", "type": "array",
@ -334,6 +401,21 @@
"maximum": 1.0, "maximum": 1.0,
"exclusiveMinimum": 0.0 "exclusiveMinimum": 0.0
}, },
"truncate": {
"type": "integer",
"default": "null",
"example": "null",
"nullable": true
},
"typical_p": {
"type": "number",
"format": "float",
"default": "null",
"example": 0.95,
"nullable": true,
"maximum": 1.0,
"exclusiveMinimum": 0.0
},
"watermark": { "watermark": {
"type": "boolean", "type": "boolean",
"default": "false", "default": "false",
@ -414,7 +496,8 @@
"seed": { "seed": {
"type": "integer", "type": "integer",
"format": "int64", "format": "int64",
"example": 42 "example": 42,
"nullable": true
} }
} }
}, },

View File

@ -31,6 +31,8 @@ struct Args {
quantize: bool, quantize: bool,
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
max_concurrent_requests: usize, max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_stop_sequences: usize, max_stop_sequences: usize,
#[clap(default_value = "1000", long, env)] #[clap(default_value = "1000", long, env)]
@ -86,6 +88,7 @@ fn main() -> ExitCode {
num_shard, num_shard,
quantize, quantize,
max_concurrent_requests, max_concurrent_requests,
max_best_of,
max_stop_sequences, max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
@ -363,6 +366,8 @@ fn main() -> ExitCode {
"text-generation-router".to_string(), "text-generation-router".to_string(),
"--max-concurrent-requests".to_string(), "--max-concurrent-requests".to_string(),
max_concurrent_requests.to_string(), max_concurrent_requests.to_string(),
"--max-best-of".to_string(),
max_best_of.to_string(),
"--max-stop-sequences".to_string(), "--max-stop-sequences".to_string(),
max_stop_sequences.to_string(), max_stop_sequences.to_string(),
"--max-input-length".to_string(), "--max-input-length".to_string(),

View File

@ -2,6 +2,7 @@
use crate::validation::{Validation, ValidationError}; use crate::validation::{Validation, ValidationError};
use crate::{Entry, Queue, Token}; use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken}; use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
use nohash_hasher::IntMap; use nohash_hasher::IntMap;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{ use text_generation_client::{
@ -177,6 +178,43 @@ impl Infer {
Err(err) Err(err)
} }
} }
/// Add best_of new requests to the queue and return a InferResponse of the sequence with
/// the highest log probability per token
#[instrument(skip(self))]
pub(crate) async fn generate_best_of(
&self,
request: GenerateRequest,
best_of: usize,
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
// validate best_of parameter separately
let best_of = self.validation.validate_best_of(best_of)?;
// create multiple generate requests
let mut infer_responses: Vec<InferResponse> =
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
// get the sequence with the highest log probability per token
let mut max_index = 0;
let mut max_logprob: f32 = f32::MIN;
for (i, response) in infer_responses.iter().enumerate() {
// mean logprobs of the generated tokens
let sequence_logprob = response
.tokens
.iter()
.map(|token| token.logprob)
.sum::<f32>()
/ response.tokens.len() as f32;
// set best sequence
if sequence_logprob > max_logprob {
max_index = i;
max_logprob = sequence_logprob;
}
}
let best_response = infer_responses.remove(max_index);
Ok((best_response, infer_responses))
}
} }
/// Batching logic /// Batching logic

View File

@ -12,6 +12,9 @@ use validation::Validation;
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters { pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
pub best_of: Option<usize>,
#[serde(default)] #[serde(default)]
#[schema( #[schema(
exclusive_minimum = 0.0, exclusive_minimum = 0.0,
@ -56,13 +59,13 @@ pub(crate) struct GenerateParameters {
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_new_tokens: u32, pub max_new_tokens: u32,
#[serde(default)] #[serde(default)]
#[schema(default = "null", example = false)] #[schema(nullable = true, default = "null", example = false)]
pub return_full_text: Option<bool>, pub return_full_text: Option<bool>,
#[serde(default)] #[serde(default)]
#[schema(inline, max_items = 4, example = json ! (["photographer"]))] #[schema(inline, max_items = 4, example = json ! (["photographer"]))]
pub stop: Vec<String>, pub stop: Vec<String>,
#[serde(default)] #[serde(default)]
#[schema(default = "null", example = "null")] #[schema(nullable = true, default = "null", example = "null")]
pub truncate: Option<usize>, pub truncate: Option<usize>,
#[serde(default)] #[serde(default)]
#[schema(default = "false", example = true)] #[schema(default = "false", example = true)]
@ -71,6 +74,12 @@ pub(crate) struct GenerateParameters {
#[schema(default = "true")] #[schema(default = "true")]
pub details: bool, pub details: bool,
#[serde(default)] #[serde(default)]
#[schema(
exclusive_minimum = 0,
nullable = true,
default = "null",
example = "null"
)]
pub seed: Option<u64>, pub seed: Option<u64>,
} }
@ -80,6 +89,7 @@ fn default_max_new_tokens() -> u32 {
fn default_parameters() -> GenerateParameters { fn default_parameters() -> GenerateParameters {
GenerateParameters { GenerateParameters {
best_of: None,
temperature: None, temperature: None,
repetition_penalty: None, repetition_penalty: None,
top_k: None, top_k: None,
@ -158,16 +168,32 @@ pub(crate) enum FinishReason {
StopSequence, StopSequence,
} }
#[derive(Serialize, ToSchema)]
pub(crate) struct BestOfSequence {
#[schema(example = "test")]
pub generated_text: String,
#[schema(example = "length")]
pub finish_reason: FinishReason,
#[schema(example = 1)]
pub generated_tokens: u32,
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>,
}
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
pub(crate) struct Details { pub(crate) struct Details {
#[schema(example = "length")] #[schema(example = "length")]
pub finish_reason: FinishReason, pub finish_reason: FinishReason,
#[schema(example = 1)] #[schema(example = 1)]
pub generated_tokens: u32, pub generated_tokens: u32,
#[schema(example = 42)] #[schema(nullable = true, example = 42)]
pub seed: Option<u64>, pub seed: Option<u64>,
pub prefill: Vec<PrefillToken>, pub prefill: Vec<PrefillToken>,
pub tokens: Vec<Token>, pub tokens: Vec<Token>,
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of_sequences: Option<Vec<BestOfSequence>>,
} }
#[derive(Serialize, ToSchema)] #[derive(Serialize, ToSchema)]
@ -184,7 +210,7 @@ pub(crate) struct StreamDetails {
pub finish_reason: FinishReason, pub finish_reason: FinishReason,
#[schema(example = 1)] #[schema(example = 1)]
pub generated_tokens: u32, pub generated_tokens: u32,
#[schema(example = 42)] #[schema(nullable = true, example = 42)]
pub seed: Option<u64>, pub seed: Option<u64>,
} }

View File

@ -23,6 +23,8 @@ use tracing_subscriber::{EnvFilter, Layer};
struct Args { struct Args {
#[clap(default_value = "128", long, env)] #[clap(default_value = "128", long, env)]
max_concurrent_requests: usize, max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)] #[clap(default_value = "4", long, env)]
max_stop_sequences: usize, max_stop_sequences: usize,
#[clap(default_value = "1000", long, env)] #[clap(default_value = "1000", long, env)]
@ -55,6 +57,7 @@ fn main() -> Result<(), std::io::Error> {
// Pattern match configuration // Pattern match configuration
let Args { let Args {
max_concurrent_requests, max_concurrent_requests,
max_best_of,
max_stop_sequences, max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
@ -145,6 +148,7 @@ fn main() -> Result<(), std::io::Error> {
server::run( server::run(
compat_return_full_text, compat_return_full_text,
max_concurrent_requests, max_concurrent_requests,
max_best_of,
max_stop_sequences, max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,

View File

@ -1,9 +1,10 @@
/// HTTP Server logic /// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse}; use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{ use crate::{
CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters, BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token, GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
Validation, StreamResponse, Token, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
@ -64,6 +65,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
.generate(GenerateRequest { .generate(GenerateRequest {
inputs: "liveness".to_string(), inputs: "liveness".to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None,
temperature: None, temperature: None,
repetition_penalty: None, repetition_penalty: None,
top_k: None, top_k: None,
@ -128,17 +130,51 @@ async fn generate(
let details = req.0.parameters.details; let details = req.0.parameters.details;
// Inference // Inference
let response = infer.generate(req.0).await?; let (response, best_of_responses) = match req.0.parameters.best_of {
Some(best_of) if best_of > 1 => {
let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
(response, Some(best_of_responses))
}
_ => (infer.generate(req.0).await?, None),
};
// Token details // Token details
let details = match details { let details = match details {
true => Some(Details { true => {
finish_reason: FinishReason::from(response.generated_text.finish_reason), // convert best_of_responses
generated_tokens: response.generated_text.generated_tokens, let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
prefill: response.prefill, responses
tokens: response.tokens, .into_iter()
seed: response.generated_text.seed, .map(|response: InferResponse| {
}), // Add prompt if return_full_text
let mut output_text = response.generated_text.text;
if let Some(prompt) = &add_prompt {
output_text = prompt.clone() + &output_text;
}
BestOfSequence {
generated_text: output_text,
finish_reason: FinishReason::from(
response.generated_text.finish_reason,
),
generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill,
tokens: response.tokens,
seed: response.generated_text.seed,
}
})
.collect()
});
Some(Details {
finish_reason: FinishReason::from(response.generated_text.finish_reason),
generated_tokens: response.generated_text.generated_tokens,
prefill: response.prefill,
tokens: response.tokens,
seed: response.generated_text.seed,
best_of_sequences,
})
}
false => None, false => None,
}; };
@ -279,107 +315,115 @@ async fn generate_stream(
} }
let details = req.0.parameters.details; let details = req.0.parameters.details;
match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await { let best_of = req.0.parameters.best_of.unwrap_or(1);
Ok(mut response_stream) => { if best_of == 1 {
// Server-Sent Event stream match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
while let Some(response) = response_stream.next().await { Ok(mut response_stream) => {
match response { // Server-Sent Event stream
Ok(response) => { while let Some(response) = response_stream.next().await {
match response { match response {
// Prefill is ignored Ok(response) => {
InferStreamResponse::Prefill(_) => {} match response {
// Yield event for every new token // Prefill is ignored
InferStreamResponse::Token(token) => { InferStreamResponse::Prefill(_) => {}
// StreamResponse // Yield event for every new token
let stream_token = StreamResponse { InferStreamResponse::Token(token) => {
token, // StreamResponse
generated_text: None, let stream_token = StreamResponse {
details: None, token,
}; generated_text: None,
details: None,
};
yield Ok(Event::default().json_data(stream_token).unwrap()) yield Ok(Event::default().json_data(stream_token).unwrap())
}
// Yield event for last token and compute timings
InferStreamResponse::End {
token,
generated_text,
start,
queued,
} => {
// Token details
let details = match details {
true => Some(StreamDetails {
finish_reason: FinishReason::from(generated_text.finish_reason),
generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed,
}),
false => None,
};
// Timings
let total_time = start_time.elapsed();
let validation_time = queued - start_time;
let queue_time = start - queued;
let inference_time = Instant::now() - start;
let time_per_token = inference_time / generated_text.generated_tokens;
// Tracing metadata
span.record("total_time", format!("{total_time:?}"));
span.record("validation_time", format!("{validation_time:?}"));
span.record("queue_time", format!("{queue_time:?}"));
span.record("inference_time", format!("{inference_time:?}"));
span.record("time_per_token", format!("{time_per_token:?}"));
span.record("seed", format!("{:?}", generated_text.seed));
tracing::info!(parent: &span, "Output: {}", generated_text.text);
// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::histogram!("tgi_request_duration", total_time);
metrics::histogram!("tgi_request_validation_duration", validation_time);
metrics::histogram!("tgi_request_queue_duration", queue_time);
metrics::histogram!("tgi_request_inference_duration", inference_time);
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
// StreamResponse
end_reached = true;
let mut output_text = generated_text.text;
if let Some(prompt) = add_prompt {
output_text = prompt + &output_text;
} }
// Yield event for last token and compute timings
let stream_token = StreamResponse { InferStreamResponse::End {
token, token,
generated_text: Some(output_text), generated_text,
details start,
}; queued,
} => {
// Token details
let details = match details {
true => Some(StreamDetails {
finish_reason: FinishReason::from(generated_text.finish_reason),
generated_tokens: generated_text.generated_tokens,
seed: generated_text.seed,
}),
false => None,
};
yield Ok(Event::default().json_data(stream_token).unwrap()); // Timings
break; let total_time = start_time.elapsed();
let validation_time = queued - start_time;
let queue_time = start - queued;
let inference_time = Instant::now() - start;
let time_per_token = inference_time / generated_text.generated_tokens;
// Tracing metadata
span.record("total_time", format!("{total_time:?}"));
span.record("validation_time", format!("{validation_time:?}"));
span.record("queue_time", format!("{queue_time:?}"));
span.record("inference_time", format!("{inference_time:?}"));
span.record("time_per_token", format!("{time_per_token:?}"));
span.record("seed", format!("{:?}", generated_text.seed));
tracing::info!(parent: &span, "Output: {}", generated_text.text);
// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::histogram!("tgi_request_duration", total_time);
metrics::histogram!("tgi_request_validation_duration", validation_time);
metrics::histogram!("tgi_request_queue_duration", queue_time);
metrics::histogram!("tgi_request_inference_duration", inference_time);
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
// StreamResponse
end_reached = true;
let mut output_text = generated_text.text;
if let Some(prompt) = add_prompt {
output_text = prompt + &output_text;
}
let stream_token = StreamResponse {
token,
generated_text: Some(output_text),
details
};
yield Ok(Event::default().json_data(stream_token).unwrap());
break;
}
} }
} }
} // yield error
// yield error Err(err) => {
Err(err) => { error = true;
error = true; yield Ok(Event::from(err));
yield Ok(Event::from(err)); break;
break; }
} }
} }
},
// yield error
Err(err) => {
error = true;
yield Ok(Event::from(err));
} }
}, }
// yield error // Check if generation reached the end
Err(err) => { // Skip if we already sent an error
error = true; if !end_reached && !error {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
tracing::error!("{err}");
yield Ok(Event::from(err)); yield Ok(Event::from(err));
} }
} } else {
// Check if generation reached the end let err = InferError::from(ValidationError::BestOfStream);
// Skip if we already sent an error metrics::increment_counter!("tgi_request_failure", "err" => "validation");
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
tracing::error!("{err}"); tracing::error!("{err}");
yield Ok(Event::from(err)); yield Ok(Event::from(err));
} }
@ -404,6 +448,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
pub async fn run( pub async fn run(
compat_return_full_text: bool, compat_return_full_text: bool,
max_concurrent_requests: usize, max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize, max_stop_sequences: usize,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize, max_total_tokens: usize,
@ -430,6 +475,7 @@ pub async fn run(
PrefillToken, PrefillToken,
Token, Token,
GenerateResponse, GenerateResponse,
BestOfSequence,
Details, Details,
FinishReason, FinishReason,
StreamResponse, StreamResponse,
@ -454,6 +500,7 @@ pub async fn run(
let validation = Validation::new( let validation = Validation::new(
validation_workers, validation_workers,
tokenizer, tokenizer,
max_best_of,
max_stop_sequences, max_stop_sequences,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,

View File

@ -1,4 +1,4 @@
use crate::validation::ValidationError::EmptyInput; use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
/// Payload validation logic /// Payload validation logic
use crate::{GenerateParameters, GenerateRequest}; use crate::{GenerateParameters, GenerateRequest};
use rand::rngs::ThreadRng; use rand::rngs::ThreadRng;
@ -13,6 +13,9 @@ use tracing::{instrument, Span};
/// Validation /// Validation
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Validation { pub struct Validation {
/// maximum value for the best_of parameter
#[allow(dead_code)]
max_best_of: usize,
/// Channel to communicate with the background validation task /// Channel to communicate with the background validation task
sender: mpsc::UnboundedSender<ValidationRequest>, sender: mpsc::UnboundedSender<ValidationRequest>,
} }
@ -21,6 +24,7 @@ impl Validation {
pub(crate) fn new( pub(crate) fn new(
workers: usize, workers: usize,
tokenizer: Tokenizer, tokenizer: Tokenizer,
max_best_of: usize,
max_stop_sequences: usize, max_stop_sequences: usize,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize, max_total_tokens: usize,
@ -39,6 +43,7 @@ impl Validation {
)); ));
Self { Self {
max_best_of,
sender: validation_sender, sender: validation_sender,
} }
} }
@ -60,6 +65,20 @@ impl Validation {
// Unwrap is safe here // Unwrap is safe here
receiver.await.unwrap() receiver.await.unwrap()
} }
/// Validate the best_of parameter
#[instrument(skip_all)]
pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {
if self.max_best_of == 1 && best_of != 1 {
return Err(ValidationError::BestOfDisabled);
}
if best_of > self.max_best_of {
return Err(ValidationError::BestOf(self.max_best_of, best_of));
}
Ok(best_of)
}
} }
/// Validation task /// Validation task
@ -150,6 +169,7 @@ fn validate(
rng: &mut ThreadRng, rng: &mut ThreadRng,
) -> Result<ValidGenerateRequest, ValidationError> { ) -> Result<ValidGenerateRequest, ValidationError> {
let GenerateParameters { let GenerateParameters {
best_of,
temperature, temperature,
repetition_penalty, repetition_penalty,
top_k, top_k,
@ -164,6 +184,18 @@ fn validate(
.. ..
} = request.parameters; } = request.parameters;
// sampling must be true when best_of > 1
let best_of = best_of.unwrap_or(1);
let sampling = do_sample
|| temperature.is_some()
|| top_k.is_some()
|| top_p.is_some()
|| typical_p.is_some();
if best_of > 1 && !sampling {
return Err(BestOfSampling);
}
let temperature = temperature.unwrap_or(1.0); let temperature = temperature.unwrap_or(1.0);
if temperature <= 0.0 { if temperature <= 0.0 {
return Err(ValidationError::Temperature); return Err(ValidationError::Temperature);
@ -217,7 +249,12 @@ fn validate(
// If seed is None, assign a random one // If seed is None, assign a random one
let seed = match seed { let seed = match seed {
None => rng.gen(), None => rng.gen(),
Some(seed) => seed, Some(seed) => {
if best_of > 1 {
return Err(BestOfSeed);
}
seed
}
}; };
// Check if inputs is empty // Check if inputs is empty
@ -307,6 +344,16 @@ pub(crate) struct ValidGenerateRequest {
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum ValidationError { pub enum ValidationError {
#[error("`best_of` must be > 0 and <= {0}. Given: {1}")]
BestOf(usize, usize),
#[error("`best_of` != 1 is not allowed for this endpoint")]
BestOfDisabled,
#[error("you must use sampling when `best_of` is > 1")]
BestOfSampling,
#[error("`seed` must not be set when `best_of` > 1")]
BestOfSeed,
#[error("`best_of` != 1 is not supported when streaming tokens")]
BestOfStream,
#[error("`temperature` must be strictly positive")] #[error("`temperature` must be strictly positive")]
Temperature, Temperature,
#[error("`repetition_penalty` must be strictly positive")] #[error("`repetition_penalty` must be strictly positive")]