Unroll notify error into generate response (#2597)

* feat: unroll notify_error if no tool is chosen

* fix: expect simple message when no tool is selected

* fix: improve test to avoid notify_error

* fix: improve docs and indicate change in expected response

* fix: adjust linting in test file
drbh 2024-10-02 11:34:57 -04:00 committed by GitHub
parent 2335459556
commit d22b0c1fbe
4 changed files with 43 additions and 44 deletions


@@ -311,11 +311,13 @@ print(chat.choices[0].message.tool_calls)
 ```
-### OpenAI integration
+### OpenAI Integration
-TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
+Text Generation Inference (TGI) offers seamless integration with OpenAI's client libraries, allowing developers to interact with TGI's Messages API and Tool functions in a familiar way. This compatibility simplifies the implementation of advanced features, such as tools and grammar, within your applications using OpenAI's client.
-However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary.
+Previously, TGI handled tool selection differently than OpenAI's API: `tool_choice="auto"` would always pick a tool for you. However, as of the latest version, TGI now mimics OpenAI's behavior more closely: `tool_choice="auto"` selects a tool only when the model deems it necessary, aligning with how OpenAI's API works. This enhancement ensures a smoother and more predictable integration experience.
+Additionally, error notifications like `notify_error`, which previously indicated that no tool was chosen, are no longer returned. Instead, TGI will proceed with generating a response as if no tool was selected, further improving consistency with OpenAI's API.
 ```python
 from openai import OpenAI

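As context for the documentation change above, here is a minimal end-to-end sketch of the new `tool_choice="auto"` behavior. It is not part of the diff: the endpoint, model name, and tool definition are illustrative, assuming a local TGI instance serving the OpenAI-compatible API.

```python
from openai import OpenAI

# Hypothetical local endpoint; point base_url at your TGI deployment.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

# A single illustrative tool definition.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="tgi",  # TGI serves one model; the name here is a placeholder
    messages=[{"role": "user", "content": "Tell me a story about 3 sea creatures"}],
    tools=tools,
    tool_choice="auto",
)

message = response.choices[0].message
if message.tool_calls:
    # The model decided a tool was needed.
    print(message.tool_calls)
else:
    # New behavior: a plain response instead of a notify_error tool call.
    print(message.content)
```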

@@ -1,38 +1,26 @@
 {
   "choices": [
     {
-      "finish_reason": "eos_token",
+      "finish_reason": "stop",
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": null,
+        "content": "There is a huge storm in the ocean",
         "name": null,
         "role": "assistant",
-        "tool_calls": [
-          {
-            "function": {
-              "arguments": {
-                "error": "Cannot get current weather forecast from specified location and temperature unit. Please try again with different options."
-              },
-              "description": null,
-              "name": "notify_error"
-            },
-            "id": 0,
-            "type": "function"
-          }
-        ]
+        "tool_calls": null
       },
       "usage": null
     }
   ],
-  "created": 1712852597,
+  "created": 1727796440,
   "id": "",
-  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-  "object": "text_completion",
-  "system_fingerprint": "1.4.5-native",
+  "model": "meta-llama/Llama-3.1-8B-Instruct",
+  "object": "chat.completion",
+  "system_fingerprint": "2.3.1-dev0-native",
   "usage": {
-    "completion_tokens": 39,
-    "prompt_tokens": 496,
-    "total_tokens": 535
+    "completion_tokens": 25,
+    "prompt_tokens": 600,
+    "total_tokens": 625
   }
 }

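The snapshot above is the new expected response shape. Here is a short sketch of reading the fields that changed (the file path is hypothetical):

```python
import json

# Hypothetical path to the snapshot shown above.
with open("test_flash_llama_grammar_tools_insufficient_information.json") as f:
    snapshot = json.load(f)

choice = snapshot["choices"][0]
assert choice["finish_reason"] == "stop"        # was "eos_token"
assert choice["message"]["tool_calls"] is None  # was a notify_error tool call
print(choice["message"]["content"])             # "There is a huge storm in the ocean"
```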

@@ -225,10 +225,6 @@ async def test_flash_llama_grammar_tools_insufficient_information(
         tools=tools,
         tool_choice="auto",
         messages=[
-            {
-                "role": "system",
-                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
-            },
             {
                 "role": "user",
                 "content": "Tell me a story about 3 sea creatures",
@@ -237,8 +233,5 @@ async def test_flash_llama_grammar_tools_insufficient_information(
         stream=False,
     )
-    assert responses.choices[0].message.content is None
-    assert (
-        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
-    )
+    assert responses.choices[0].message.content == "There is a huge storm in the ocean"
     assert responses == response_snapshot
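For contrast, a hypothetical companion test (not in this commit) could pin down the case where a tool should still be selected. It mirrors the call shape above and assumes the same `flash_llama_grammar_tools` fixture and `tools` list; the seed and expected shape are illustrative.

```python
@pytest.mark.asyncio
async def test_flash_llama_grammar_tools_weather_question(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=42,
        tools=tools,
        tool_choice="auto",
        messages=[
            {"role": "user", "content": "What is the weather like in Paris?"},
        ],
        stream=False,
    )
    # With a weather question, "auto" should still produce a tool call.
    assert responses.choices[0].message.tool_calls is not None
    assert responses.choices[0].message.content is None
```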


@@ -1246,17 +1246,33 @@ async fn chat_completions(
         if let Value::Object(ref mut props) = arguments {
             props.remove("_name");
         }
-        let tool_calls = vec![ToolCall {
-            id: "0".to_string(),
-            r#type: "function".to_string(),
-            function: FunctionDefinition {
-                description: None,
-                name,
-                arguments,
-            },
-        }];
-        (Some(tool_calls), None)
+        match name.as_str() {
+            "notify_error" => {
+                // parse the error message
+                let error_message = arguments
+                    .get("error")
+                    .and_then(Value::as_str)
+                    .ok_or_else(|| {
+                        InferError::ToolError(
+                            "No error message found in generated text".to_string(),
+                        )
+                    })?
+                    .to_string();
+                (None, Some(error_message))
+            }
+            _ => {
+                let tool_calls = vec![ToolCall {
+                    id: "0".to_string(),
+                    r#type: "function".to_string(),
+                    function: FunctionDefinition {
+                        description: None,
+                        name,
+                        arguments,
+                    },
+                }];
+                (Some(tool_calls), None)
+            }
+        }
     } else {
         (None, Some(generation.generated_text))
     };
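
To make the router change easy to follow, here is a small Python mirror of the new `match` (a sketch for illustration only; the actual implementation is the Rust above). The first element of the returned pair maps to `tool_calls`, the second to the message `content`.

```python
from typing import Any, Optional


def unroll_tool_output(
    name: str, arguments: dict[str, Any]
) -> tuple[Optional[list[dict]], Optional[str]]:
    """A notify_error pseudo-tool is unrolled into plain message content;
    any other tool name becomes a regular tool call."""
    if name == "notify_error":
        error = arguments.get("error")
        if not isinstance(error, str):
            # Corresponds to InferError::ToolError in the Rust code above.
            raise ValueError("No error message found in generated text")
        return None, error
    tool_call = {
        "id": "0",
        "type": "function",
        "function": {"description": None, "name": name, "arguments": arguments},
    }
    return [tool_call], None


# Usage: the error path yields plain content, the normal path a tool call.
assert unroll_tool_output("notify_error", {"error": "no matching tool"}) == (
    None,
    "no matching tool",
)
```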