Add stop parameter to completions route
parent 9fb1cdc8d5
commit 2f644779cb
@@ -1121,6 +1121,15 @@
       "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
       "example": 0.95,
       "nullable": true
     },
+    "stop": {
+      "type": "array",
+      "items": {
+        "type": "string"
+      },
+      "description": "Up to 4 sequences where the API will stop generating further tokens.",
+      "example": "null",
+      "nullable": true
+    }
   }
 },
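For illustration, a request body that exercises the new schema entry might look like the sketch below. It builds the JSON with serde_json; the prompt and stop values are invented for the example, not taken from this commit.

// Sketch of a completions request body, assuming the schema above:
// `stop` is an optional, nullable array of up to 4 strings.
use serde_json::json;

fn main() {
    let body = json!({
        "prompt": "What is Deep Learning?",
        "max_tokens": 100,
        "temperature": 0.95,
        // generation halts as soon as any of these sequences is produced
        "stop": ["\n\n", "###"]
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}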
@@ -401,6 +401,11 @@ pub struct CompletionRequest {
     #[serde(default)]
     #[schema(example = "1.0")]
     pub frequency_penalty: Option<f32>,
+
+    /// Up to 4 sequences where the API will stop generating further tokens.
+    #[serde(default)]
+    #[schema(nullable = true, example = "null")]
+    pub stop: Option<Vec<String>>,
 }

 #[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
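A minimal sketch of what `#[serde(default)]` buys here, using a stripped-down stand-in for `CompletionRequest` (only the new field, not the real router type): a missing key, an explicit `null`, and a concrete list all deserialize without error.

use serde::Deserialize;

// Stand-in for the relevant slice of CompletionRequest.
#[derive(Deserialize, Debug)]
struct Req {
    #[serde(default)]
    stop: Option<Vec<String>>,
}

fn main() {
    // Absent key falls back to Default (None); JSON null also maps to None.
    let absent: Req = serde_json::from_str("{}").unwrap();
    let null: Req = serde_json::from_str(r#"{"stop": null}"#).unwrap();
    let list: Req = serde_json::from_str(r#"{"stop": ["###"]}"#).unwrap();
    assert!(absent.stop.is_none());
    assert!(null.stop.is_none());
    assert_eq!(list.stop, Some(vec!["###".to_string()]));
}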
@@ -597,9 +597,22 @@ async fn completions(
     let span = tracing::Span::current();
     metrics::increment_counter!("tgi_request_count");

-    let stream = req.stream;
-    let max_new_tokens = req.max_tokens.or(Some(100));
-    let seed = req.seed;
+    let CompletionRequest {
+        max_tokens,
+        seed,
+        stop,
+        stream,
+        temperature,
+        ..
+    } = req;
+
+    let max_new_tokens = max_tokens.or(Some(100));
+    let stop = stop.unwrap_or_default();
+    // enable greedy only when temperature is 0
+    let (do_sample, temperature) = match temperature {
+        Some(temperature) if temperature == 0.0 => (false, None),
+        other => (true, other),
+    };

     // if suffix is present throw an error
     if req.suffix.is_some() {
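The `match` above encodes the sampling policy: a temperature of exactly 0 switches to greedy decoding and drops the temperature, while anything else, including an absent temperature, keeps sampling enabled. A standalone sketch of that mapping (a free function for illustration, not the handler itself), easy to verify in isolation:

// Mirrors the match in the handler: returns (do_sample, effective temperature).
fn sampling_mode(temperature: Option<f32>) -> (bool, Option<f32>) {
    match temperature {
        Some(t) if t == 0.0 => (false, None),
        other => (true, other),
    }
}

fn main() {
    assert_eq!(sampling_mode(Some(0.0)), (false, None)); // greedy
    assert_eq!(sampling_mode(Some(0.7)), (true, Some(0.7))); // sampling
    assert_eq!(sampling_mode(None), (true, None)); // backend default
}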
@@ -629,22 +642,22 @@
     }

     let generate_requests: Vec<GenerateRequest> = req
         .prompt
         .iter()
         .map(|prompt| GenerateRequest {
             inputs: prompt.to_string(),
             parameters: GenerateParameters {
                 best_of: None,
-                temperature: req.temperature,
+                temperature: temperature,
                 repetition_penalty: req.repetition_penalty,
                 frequency_penalty: req.frequency_penalty,
                 top_k: None,
                 top_p: req.top_p,
                 typical_p: None,
-                do_sample: true,
+                do_sample,
                 max_new_tokens,
                 return_full_text: None,
-                stop: Vec::new(),
+                stop: stop.clone(),
                 truncate: None,
                 watermark: false,
                 details: true,
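Because the handler fans one `CompletionRequest` out into a `GenerateRequest` per prompt, each request needs its own copy of the stop sequences, hence `stop.clone()` inside the `map`. A reduced sketch of that pattern, with a simplified stand-in for the parameter type:

// Simplified stand-in for GenerateParameters; only the field that matters here.
#[derive(Debug, Clone, PartialEq)]
struct Params {
    stop: Vec<String>,
}

fn main() {
    let prompts = vec!["Hello", "Bonjour"];
    let stop = vec!["###".to_string()];

    // One parameter set per prompt, each owning its own copy of `stop`,
    // mirroring `stop: stop.clone()` in the diff above.
    let requests: Vec<(String, Params)> = prompts
        .iter()
        .map(|p| (p.to_string(), Params { stop: stop.clone() }))
        .collect();

    assert_eq!(requests.len(), 2);
    assert_eq!(requests[0].1, requests[1].1);
}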