Add stop parameter to completions route

This commit is contained in:
Thomas SCHILLACI 2024-05-07 19:27:05 +02:00
parent 9fb1cdc8d5
commit 2f644779cb
3 changed files with 34 additions and 7 deletions

View File

@ -1121,6 +1121,15 @@
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.", "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
"example": 0.95, "example": 0.95,
"nullable": true "nullable": true
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
} }
} }
}, },

View File

@ -401,6 +401,11 @@ pub struct CompletionRequest {
#[serde(default)] #[serde(default)]
#[schema(example = "1.0")] #[schema(example = "1.0")]
pub frequency_penalty: Option<f32>, pub frequency_penalty: Option<f32>,
/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stop: Option<Vec<String>>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]

View File

@ -597,9 +597,22 @@ async fn completions(
let span = tracing::Span::current(); let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
let stream = req.stream; let CompletionRequest {
let max_new_tokens = req.max_tokens.or(Some(100)); max_tokens,
let seed = req.seed; seed,
stop,
stream,
temperature,
..
} = req;
let max_new_tokens = max_tokens.or(Some(100));
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
// if suffix is present throw an error // if suffix is present throw an error
if req.suffix.is_some() { if req.suffix.is_some() {
@ -629,22 +642,22 @@ async fn completions(
} }
let generate_requests: Vec<GenerateRequest> = req let generate_requests: Vec<GenerateRequest> = req
.prompt .prompt
.iter() .iter()
.map(|prompt| GenerateRequest { .map(|prompt| GenerateRequest {
inputs: prompt.to_string(), inputs: prompt.to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature: req.temperature, temperature: temperature,
repetition_penalty: req.repetition_penalty, repetition_penalty: req.repetition_penalty,
frequency_penalty: req.frequency_penalty, frequency_penalty: req.frequency_penalty,
top_k: None, top_k: None,
top_p: req.top_p, top_p: req.top_p,
typical_p: None, typical_p: None,
do_sample: true, do_sample,
max_new_tokens, max_new_tokens,
return_full_text: None, return_full_text: None,
stop: Vec::new(), stop: stop.clone(),
truncate: None, truncate: None,
watermark: false, watermark: false,
details: true, details: true,