fix: adjust tool grammar ownership
This commit is contained in:
parent
bb73acc1a9
commit
9874b15fa8
|
@ -756,19 +756,34 @@ async fn chat_completions(
|
|||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
|
||||
let stream = req.stream;
|
||||
let max_new_tokens = req.max_tokens.or(Some(100));
|
||||
let repetition_penalty = req
|
||||
.presence_penalty
|
||||
// rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
|
||||
.map(|x| x + 2.0);
|
||||
let logprobs = req.logprobs.unwrap_or(false);
|
||||
let seed = req.seed;
|
||||
let stop = req.stop.unwrap_or_default();
|
||||
let tool_prompt = req.tool_prompt.unwrap_or_default();
|
||||
let ChatRequest {
|
||||
frequency_penalty: _,
|
||||
logit_bias: _,
|
||||
logprobs,
|
||||
max_tokens,
|
||||
messages,
|
||||
model: _,
|
||||
n: _,
|
||||
presence_penalty,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
temperature: _,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
top_p: _,
|
||||
top_logprobs: _,
|
||||
} = req;
|
||||
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let logprobs = logprobs.unwrap_or(false);
|
||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let stop = stop.unwrap_or_default();
|
||||
|
||||
// extract tool grammar if present
|
||||
let tool_grammar = match ToolGrammar::apply(req.tools.as_ref(), req.tool_choice.as_ref()) {
|
||||
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
|
||||
Ok(grammar) => grammar,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
|
@ -784,7 +799,7 @@ async fn chat_completions(
|
|||
};
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||
let mut inputs = match infer.apply_chat_template(messages) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
|
|
Loading…
Reference in New Issue