Add stop parameter to completions route

This commit is contained in:
Thomas SCHILLACI 2024-05-07 19:27:05 +02:00
parent 9fb1cdc8d5
commit 2f644779cb
3 changed files with 34 additions and 7 deletions

View File

@ -1121,6 +1121,15 @@
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
"example": 0.95,
"nullable": true
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
}
}
},

View File

@ -401,6 +401,11 @@ pub struct CompletionRequest {
#[serde(default)]
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stop: Option<Vec<String>>,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]

View File

@ -597,9 +597,22 @@ async fn completions(
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count");
let stream = req.stream;
let max_new_tokens = req.max_tokens.or(Some(100));
let seed = req.seed;
let CompletionRequest {
max_tokens,
seed,
stop,
stream,
temperature,
..
} = req;
let max_new_tokens = max_tokens.or(Some(100));
let stop = stop.unwrap_or_default();
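// Use greedy decoding (do_sample = false, temperature unset) only when the requested temperature is exactly 0; otherwise sample with the given temperature.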
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
// If a suffix is present, return an error.
if req.suffix.is_some() {
@ -629,22 +642,22 @@ async fn completions(
}
let generate_requests: Vec<GenerateRequest> = req
.prompt
.prompt
.iter()
.map(|prompt| GenerateRequest {
inputs: prompt.to_string(),
parameters: GenerateParameters {
best_of: None,
temperature: req.temperature,
temperature: temperature,
repetition_penalty: req.repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None,
top_p: req.top_p,
typical_p: None,
do_sample: true,
do_sample,
max_new_tokens,
return_full_text: None,
stop: Vec::new(),
stop: stop.clone(),
truncate: None,
watermark: false,
details: true,