From d8402eaf6723818eec2d8abf7715b9dc42da07df Mon Sep 17 00:00:00 2001 From: phangiabao98 <60313144+phangiabao98@users.noreply.github.com> Date: Thu, 16 May 2024 15:17:00 +0700 Subject: [PATCH] OpenAI function calling compatible support (#1888) # What does this PR do? Fixes # (issue) https://github.com/huggingface/text-generation-inference/issues/1887 ## Before submitting - [no ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [yes] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ yes] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [yes ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ yes] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @Narsil --> --------- Co-authored-by: Bao Phan --- router/src/infer.rs | 23 +++++++++++++++++++++++ router/src/lib.rs | 8 ++++++-- router/src/server.rs | 3 +-- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index a4945512..bfa7b55c 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -979,24 +979,28 @@ mod tests { content: Some("Hi!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("Hello how can I help?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("What is Deep Learning?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("magic!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ], bos_token: Some("[BOS]"), @@ -1049,30 +1053,35 @@ mod tests { content: Some("Hi!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("Hi again!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("Hello how can I help?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("What is Deep Learning?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("magic!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ], bos_token: Some("[BOS]"), @@ -1130,24 +1139,28 @@ mod tests { content: Some("Hi!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("Hello how can I help?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("What is Deep Learning?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("magic!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ], bos_token: Some("[BOS]"), @@ -1189,24 +1202,28 @@ mod tests { content: Some("Hi!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("Hello how can I help?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("What is Deep Learning?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("magic!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ], bos_token: Some("[BOS]"), @@ -1234,18 +1251,21 @@ mod tests { content: Some("Hello, how are you?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "assistant".to_string(), content: Some("I'm doing great. How can I help you today?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("I'd like to show off how chat templating works!".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ]; @@ -1257,6 +1277,7 @@ mod tests { ), name: None, tool_calls: None, + tool_call_id: None, }] .iter() .chain(&example_chat) @@ -1401,12 +1422,14 @@ mod tests { content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, Message { role: "user".to_string(), content: Some("How many helicopters can a human eat in one sitting?".to_string()), name: None, tool_calls: None, + tool_call_id: None, }, ], add_generation_prompt: true, diff --git a/router/src/lib.rs b/router/src/lib.rs index 96a9fdf6..85e18dfb 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -546,6 +546,7 @@ impl ChatCompletion { content: output, name: None, tool_calls, + tool_call_id: None, }, logprobs: return_logprobs .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), @@ -881,7 +882,7 @@ pub(crate) struct ChatTemplateInputs<'a> { #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] pub(crate) struct ToolCall { - pub id: u32, + pub id: String, pub r#type: String, pub function: FunctionDefinition, } @@ -954,13 +955,16 @@ pub(crate) struct Message { pub role: String, #[serde(skip_serializing_if = "Option::is_none")] #[schema(example = "My name is David and I")] - #[serde(deserialize_with = "message_content_serde::deserialize")] + #[serde(default, deserialize_with = "message_content_serde::deserialize")] pub content: Option, #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "\"David\"")] pub name: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + #[schema(example = "\"get_weather\"")] + pub tool_call_id: Option, } #[derive(Clone, Debug, Deserialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index cb55d897..f51bbbef 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -988,7 +988,6 @@ async fn chat_completions( ) -> Result)> { let span = tracing::Span::current(); metrics::increment_counter!("tgi_request_count"); - let ChatRequest { logprobs, max_tokens, @@ -1160,7 +1159,7 @@ async fn chat_completions( ) })?; let tool_calls = vec![ToolCall { - id: 0, + id: "0".to_string(), r#type: "function".to_string(), function: FunctionDefinition { description: None,