Add stop parameter to completions route
This commit is contained in:
parent
9fb1cdc8d5
commit
2f644779cb
|
@ -1121,6 +1121,15 @@
|
||||||
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
|
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
|
||||||
"example": 0.95,
|
"example": 0.95,
|
||||||
"nullable": true
|
"nullable": true
|
||||||
|
},
|
||||||
|
"stop": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"description": "Up to 4 sequences where the API will stop generating further tokens.",
|
||||||
|
"example": "null",
|
||||||
|
"nullable": true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
|
@ -401,6 +401,11 @@ pub struct CompletionRequest {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(example = "1.0")]
|
#[schema(example = "1.0")]
|
||||||
pub frequency_penalty: Option<f32>,
|
pub frequency_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Up to 4 sequences where the API will stop generating further tokens.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
|
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
|
||||||
|
|
|
@ -597,9 +597,22 @@ async fn completions(
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
metrics::increment_counter!("tgi_request_count");
|
metrics::increment_counter!("tgi_request_count");
|
||||||
|
|
||||||
let stream = req.stream;
|
let CompletionRequest {
|
||||||
let max_new_tokens = req.max_tokens.or(Some(100));
|
max_tokens,
|
||||||
let seed = req.seed;
|
seed,
|
||||||
|
stop,
|
||||||
|
stream,
|
||||||
|
temperature,
|
||||||
|
..
|
||||||
|
} = req;
|
||||||
|
|
||||||
|
let max_new_tokens = max_tokens.or(Some(100));
|
||||||
|
let stop = stop.unwrap_or_default();
|
||||||
|
// enable greedy only when temperature is 0
|
||||||
|
let (do_sample, temperature) = match temperature {
|
||||||
|
Some(temperature) if temperature == 0.0 => (false, None),
|
||||||
|
other => (true, other),
|
||||||
|
};
|
||||||
|
|
||||||
// if suffix is present throw an error
|
// if suffix is present throw an error
|
||||||
if req.suffix.is_some() {
|
if req.suffix.is_some() {
|
||||||
|
@ -635,16 +648,16 @@ async fn completions(
|
||||||
inputs: prompt.to_string(),
|
inputs: prompt.to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
best_of: None,
|
best_of: None,
|
||||||
temperature: req.temperature,
|
temperature: temperature,
|
||||||
repetition_penalty: req.repetition_penalty,
|
repetition_penalty: req.repetition_penalty,
|
||||||
frequency_penalty: req.frequency_penalty,
|
frequency_penalty: req.frequency_penalty,
|
||||||
top_k: None,
|
top_k: None,
|
||||||
top_p: req.top_p,
|
top_p: req.top_p,
|
||||||
typical_p: None,
|
typical_p: None,
|
||||||
do_sample: true,
|
do_sample,
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
return_full_text: None,
|
return_full_text: None,
|
||||||
stop: Vec::new(),
|
stop: stop.clone(),
|
||||||
truncate: None,
|
truncate: None,
|
||||||
watermark: false,
|
watermark: false,
|
||||||
details: true,
|
details: true,
|
||||||
|
|
Loading…
Reference in New Issue