Support tools (#1587)
This work-in-progress PR begins to add support for tools. Tools rely on grammar support and still have some unsolved challenges. Opening the PR for visibility and feedback.
This commit is contained in:
parent
bf700e7eef
commit
9b6db5f793
|
@ -3,7 +3,7 @@ import requests
|
|||
|
||||
from aiohttp import ClientSession, ClientTimeout
|
||||
from pydantic import ValidationError
|
||||
from typing import Dict, Optional, List, AsyncIterator, Iterator
|
||||
from typing import Dict, Optional, List, AsyncIterator, Iterator, Union
|
||||
|
||||
from text_generation.types import (
|
||||
StreamResponse,
|
||||
|
@ -11,6 +11,11 @@ from text_generation.types import (
|
|||
Request,
|
||||
Parameters,
|
||||
Grammar,
|
||||
ChatRequest,
|
||||
ChatCompletionChunk,
|
||||
ChatComplete,
|
||||
Message,
|
||||
Tool,
|
||||
)
|
||||
from text_generation.errors import parse_error
|
||||
|
||||
|
@ -59,6 +64,114 @@ class Client:
|
|||
self.cookies = cookies
|
||||
self.timeout = timeout
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[List[float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
stream: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Given a list of messages, generate a response
|
||||
|
||||
Args:
|
||||
messages (`List[Message]`):
|
||||
List of messages
|
||||
frequency_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
logit_bias (`List[float]`):
|
||||
Adjust the likelihood of specified tokens
|
||||
logprobs (`bool`):
|
||||
Include log probabilities in the response
|
||||
top_logprobs (`int`):
|
||||
Include the `n` most likely tokens at each step
|
||||
max_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
n (`int`):
|
||||
Generate `n` completions
|
||||
presence_penalty (`float`):
|
||||
The parameter for presence penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
stream (`bool`):
|
||||
Stream the response
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
temperature (`float`):
|
||||
The value used to modulate the logits distribution.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation
|
||||
tools (`List[Tool]`):
|
||||
List of tools to use
|
||||
tool_choice (`str`):
|
||||
The tool to use
|
||||
|
||||
"""
|
||||
request = ChatRequest(
|
||||
model="tgi",
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
stream=stream,
|
||||
seed=seed,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
if not stream:
|
||||
resp = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json=request.dict(),
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
payload = resp.json()
|
||||
if resp.status_code != 200:
|
||||
raise parse_error(resp.status_code, payload)
|
||||
return ChatComplete(**payload)
|
||||
else:
|
||||
return self._chat_stream_response(request)
|
||||
|
||||
def _chat_stream_response(self, request):
|
||||
resp = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json=request.dict(),
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
timeout=self.timeout,
|
||||
stream=True,
|
||||
)
|
||||
# iterate and print stream
|
||||
for byte_payload in resp.iter_lines():
|
||||
if byte_payload == b"\n":
|
||||
continue
|
||||
payload = byte_payload.decode("utf-8")
|
||||
if payload.startswith("data:"):
|
||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||
try:
|
||||
response = ChatCompletionChunk(**json_payload)
|
||||
yield response
|
||||
except ValidationError:
|
||||
raise parse_error(resp.status_code, json_payload)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
|
@ -313,6 +426,113 @@ class AsyncClient:
|
|||
self.cookies = cookies
|
||||
self.timeout = ClientTimeout(timeout * 60)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[Message],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[List[float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
stream: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
|
||||
"""
|
||||
Given a list of messages, generate a response asynchronously
|
||||
|
||||
Args:
|
||||
messages (`List[Message]`):
|
||||
List of messages
|
||||
frequency_penalty (`float`):
|
||||
The parameter for frequency penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
logit_bias (`List[float]`):
|
||||
Adjust the likelihood of specified tokens
|
||||
logprobs (`bool`):
|
||||
Include log probabilities in the response
|
||||
top_logprobs (`int`):
|
||||
Include the `n` most likely tokens at each step
|
||||
max_tokens (`int`):
|
||||
Maximum number of generated tokens
|
||||
n (`int`):
|
||||
Generate `n` completions
|
||||
presence_penalty (`float`):
|
||||
The parameter for presence penalty. 0.0 means no penalty. See [this
|
||||
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
|
||||
stream (`bool`):
|
||||
Stream the response
|
||||
seed (`int`):
|
||||
Random sampling seed
|
||||
temperature (`float`):
|
||||
The value used to modulate the logits distribution.
|
||||
top_p (`float`):
|
||||
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
|
||||
higher are kept for generation
|
||||
tools (`List[Tool]`):
|
||||
List of tools to use
|
||||
tool_choice (`str`):
|
||||
The tool to use
|
||||
|
||||
"""
|
||||
request = ChatRequest(
|
||||
model="tgi",
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
stream=stream,
|
||||
seed=seed,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
if not stream:
|
||||
return await self._chat_single_response(request)
|
||||
else:
|
||||
return self._chat_stream_response(request)
|
||||
|
||||
async def _chat_single_response(self, request):
|
||||
async with ClientSession(
|
||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/v1/chat/completions", json=request.dict()
|
||||
) as resp:
|
||||
payload = await resp.json()
|
||||
if resp.status != 200:
|
||||
raise parse_error(resp.status, payload)
|
||||
return ChatComplete(**payload)
|
||||
|
||||
async def _chat_stream_response(self, request):
|
||||
async with ClientSession(
|
||||
headers=self.headers, cookies=self.cookies, timeout=self.timeout
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/v1/chat/completions", json=request.dict()
|
||||
) as resp:
|
||||
async for byte_payload in resp.content:
|
||||
if byte_payload == b"\n":
|
||||
continue
|
||||
payload = byte_payload.decode("utf-8")
|
||||
if payload.startswith("data:"):
|
||||
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
|
||||
try:
|
||||
response = ChatCompletionChunk(**json_payload)
|
||||
yield response
|
||||
except ValidationError:
|
||||
raise parse_error(resp.status, json_payload)
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from enum import Enum
|
||||
from pydantic import BaseModel, validator
|
||||
from typing import Optional, List, Union
|
||||
from typing import Optional, List, Union, Any
|
||||
|
||||
from text_generation.errors import ValidationError
|
||||
|
||||
|
@ -19,6 +19,124 @@ class Grammar(BaseModel):
|
|||
value: Union[str, dict]
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
# Id of the tool call
|
||||
id: int
|
||||
# Type of the tool call
|
||||
type: str
|
||||
# Function details of the tool call
|
||||
function: dict
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
# Role of the message sender
|
||||
role: str
|
||||
# Content of the message
|
||||
content: Optional[str]
|
||||
# Optional name of the message sender
|
||||
name: Optional[str] = None
|
||||
# Tool calls associated with the chat completion
|
||||
tool_calls: Optional[Any] = None
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
# Type of the tool
|
||||
type: str
|
||||
# Function details of the tool
|
||||
function: dict
|
||||
|
||||
|
||||
class ChatCompletionComplete(BaseModel):
|
||||
# Index of the chat completion
|
||||
index: int
|
||||
# Message associated with the chat completion
|
||||
message: Message
|
||||
# Log probabilities for the chat completion
|
||||
logprobs: Optional[Any]
|
||||
# Reason for completion
|
||||
finish_reason: str
|
||||
# Usage details of the chat completion
|
||||
usage: Any
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
name: Optional[str]
|
||||
arguments: str
|
||||
|
||||
|
||||
class ChoiceDeltaToolCall(BaseModel):
|
||||
index: int
|
||||
id: str
|
||||
type: str
|
||||
function: Function
|
||||
|
||||
|
||||
class ChoiceDelta(BaseModel):
|
||||
role: str
|
||||
content: Optional[str]
|
||||
tool_calls: Optional[ChoiceDeltaToolCall]
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
index: int
|
||||
delta: ChoiceDelta
|
||||
logprobs: Optional[dict] = None
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionChunk(BaseModel):
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
choices: List[Choice]
|
||||
|
||||
|
||||
class ChatComplete(BaseModel):
|
||||
# Chat completion details
|
||||
id: str
|
||||
object: str
|
||||
created: int
|
||||
model: str
|
||||
system_fingerprint: str
|
||||
choices: List[ChatCompletionComplete]
|
||||
usage: Any
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
# Model identifier
|
||||
model: str
|
||||
# List of messages in the conversation
|
||||
messages: List[Message]
|
||||
# Penalty for frequency of new tokens
|
||||
frequency_penalty: Optional[float] = None
|
||||
# Bias values for token selection
|
||||
logit_bias: Optional[List[float]] = None
|
||||
# Whether to return log probabilities
|
||||
logprobs: Optional[bool] = None
|
||||
# Number of most likely tokens to return at each position
|
||||
top_logprobs: Optional[int] = None
|
||||
# Maximum number of tokens to generate
|
||||
max_tokens: Optional[int] = None
|
||||
# Number of chat completion choices to generate
|
||||
n: Optional[int] = None
|
||||
# Penalty for presence of new tokens
|
||||
presence_penalty: Optional[float] = None
|
||||
# Flag to indicate streaming response
|
||||
stream: bool = False
|
||||
# Random sampling seed
|
||||
seed: Optional[int] = None
|
||||
# Sampling temperature
|
||||
temperature: Optional[float] = None
|
||||
# Top-p value for nucleus sampling
|
||||
top_p: Optional[float] = None
|
||||
# List of tools to be used
|
||||
tools: Optional[List[Tool]] = None
|
||||
# Choice of tool to be used
|
||||
tool_choice: Optional[str] = None
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
# Activate logits sampling
|
||||
do_sample: bool = False
|
||||
|
|
|
@ -9,6 +9,8 @@
|
|||
title: Supported Models and Hardware
|
||||
- local: messages_api
|
||||
title: Messages API
|
||||
- local: guidance
|
||||
title: Guidance
|
||||
title: Getting started
|
||||
- sections:
|
||||
- local: basic_tutorials/consuming_tgi
|
||||
|
|
|
@ -0,0 +1,419 @@
# Guidance

Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.

These features are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!

## Quick Start

Before we jump into the deep end, ensure your system is using TGI version `1.4.3` or later to access all the features we're about to explore in this guide.

If you're not up to date, grab the latest version and let's get started!

## Table of Contents 📚

### Grammar and Constraints

- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.

### Tools and Functions

- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions.
- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions.
- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.

## Grammar and Constraints 🛣️

### The Grammar Parameter

In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the AI. This is a game-changer for those who need precise control over the AI's output.

Using curl, you can make a request to TGI's `/generate` endpoint with the grammar parameter. This is the most primitive way to interact with the API, and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability.

```bash
curl localhost:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
    "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park",
    "parameters": {
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": {
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "activity": {
                        "type": "string"
                    },
                    "animals_seen": {
                        "type": "integer",
                        "minimum": 1,
                        "maximum": 5
                    },
                    "animals": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["location", "activity", "animals_seen", "animals"]
            }
        }
    }
}'
// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"}

```

A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.

> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.

### Constrain with Pydantic

Pydantic is a powerful library for data validation and settings management. It's the perfect tool for crafting a specific response format.

Using Pydantic models, we can define the same grammar as in the previous example in a shorter and more readable way.

```python
import requests
from pydantic import BaseModel, conint
from typing import List

class Animals(BaseModel):
    location: str
    activity: str
    animals_seen: conint(ge=1, le=5)  # Constrained integer type
    animals: List[str]

prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park"

data = {
    "inputs": prompt,
    "parameters": {
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": Animals.schema()
        }
    }
}

headers = {
    "Content-Type": "application/json",
}

response = requests.post(
    'http://127.0.0.1:3000/generate',
    headers=headers,
    json=data
)
print(response.json())
# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'}

```
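
Because the response is constrained to the schema above, it can usually be parsed straight back into the Pydantic model. A minimal sketch, assuming the request succeeded and `generated_text` contains schema-conforming JSON:

```python
import json

# Parse the constrained output back into the Animals model defined above.
generated = response.json()["generated_text"]
animals = Animals(**json.loads(generated))

print(animals.animals_seen)  # e.g. 3
print(animals.animals)       # e.g. ['puppy', 'cat', 'raccoon']
```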

### JSON Schema Integration

If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is similar to the first example but with programmatic control.

```python
import requests

json_schema = {
    "properties": {
        "location": {
            "type": "string"
        },
        "activity": {
            "type": "string"
        },
        "animals_seen": {
            "type": "integer",
            "minimum": 1,
            "maximum": 5
        },
        "animals": {
            "type": "array",
            "items": {
                "type": "string"
            }
        }
    },
    "required": ["location", "activity", "animals_seen", "animals"]
}

data = {
    "inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]",
    "parameters": {
        "max_new_tokens": 200,
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": json_schema
        }
    }
}

headers = {
    "Content-Type": "application/json",
}

response = requests.post(
    'http://127.0.0.1:3000/generate',
    headers=headers,
    json=data
)
print(response.json())
# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'}

```

### Using the client

TGI provides a client library that makes it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter.

```python
from text_generation import AsyncClient
from text_generation.types import GrammarType

# Define an async function to encapsulate the async operation
async def main():
    client = AsyncClient(base_url="http://localhost:3000")

    # Use 'await' to wait for the async method 'generate' to complete
    response = await client.generate(
        "Whats Googles DNS",
        max_new_tokens=10,
        decoder_input_details=True,
        seed=1,
        grammar={
            "type": GrammarType.Regex,
            "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
        },
    )

    # Once the response is received, you can process it
    print(response.generated_text)

# Ensure the main async function is run in the event loop
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

# 118.8.0.84

```
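
The client accepts JSON grammars in the same way. Below is a minimal sketch that reuses the `Animals` Pydantic model from the earlier example; it assumes a `GrammarType.Json` variant mirroring the `Regex` one shown above.

```python
from text_generation import AsyncClient
from text_generation.types import GrammarType

# Assumes the Animals model from the "Constrain with Pydantic" example is in scope.
async def main():
    client = AsyncClient(base_url="http://localhost:3000")

    # Constrain the output to the JSON schema generated from the Animals model.
    response = await client.generate(
        "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park",
        max_new_tokens=100,
        seed=1,
        grammar={
            "type": GrammarType.Json,
            "value": Animals.schema(),
        },
    )
    print(response.generated_text)

if __name__ == "__main__":
    import asyncio
    asyncio.run(main())
```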

## Tools and Functions 🛠️

### The Tools Parameter

In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API.

Tools are a set of user-defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more.

Functions, similar to grammars, are defined as JSON schemas and can be passed as part of the parameters to the Messages API.

```bash
curl localhost:3000/v1/chat/completions \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
    "model": "tgi",
    "messages": [
        {
            "role": "user",
            "content": "What is the weather like in New York?"
        }
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location."
                        }
                    },
                    "required": ["location", "format"]
                }
            }
        }
    ],
    "tool_choice": "get_current_weather"
}'
// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}}
```

<details>
<summary>Tools used in the examples below</summary>

```python
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["location", "format", "num_days"],
            },
        },
    }
]
```

</details>

### Text Generation Inference Client

TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions; an asynchronous example is shown here, with a synchronous sketch just after it.

```python
from text_generation import AsyncClient

# NOTE: tools defined above and removed for brevity

# Define an async function to encapsulate the async operation
async def main():
    client = AsyncClient(base_url="http://localhost:3000")

    # Use 'await' to wait for the async method 'chat' to complete
    response = await client.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        presence_penalty=-1.1,
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )

    # Once the response is received, you can process it
    print(response.choices[0].message.tool_calls)

# Ensure the main async function is run in the event loop
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}}

```
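
The synchronous `Client` added in this PR exposes the same `chat` method, so the call above can also be written without asyncio. A minimal sketch, assuming the same `tools` list defined above:

```python
from text_generation import Client

# NOTE: tools defined above and removed for brevity
client = Client(base_url="http://localhost:3000")

response = client.chat(
    max_tokens=100,
    seed=1,
    tools=tools,
    presence_penalty=-1.1,
    messages=[
        {
            "role": "system",
            "content": "You're a helpful assistant! Answer the users question best you can.",
        },
        {
            "role": "user",
            "content": "What is the weather like in Brooklyn, New York?",
        },
    ],
)

print(response.choices[0].message.tool_calls)
```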

### OpenAI integration

TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.

However, there are some minor differences in the API. For example, `tool_choice="auto"` will ALWAYS choose a tool for you. This is different from OpenAI's API, where `tool_choice="auto"` will only choose a tool if the model thinks it's necessary.

```python
from openai import OpenAI

# Initialize the client, pointing it to one of the available models
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="_",
)

# NOTE: tools defined above and removed for brevity

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {
            "role": "system",
            "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
        },
        {
            "role": "user",
            "content": "What's the weather like the next 3 days in San Francisco, CA?",
        },
    ],
    tools=tools,
    tool_choice="auto",  # tool selected by model
    max_tokens=500,
)

called = chat_completion.choices[0].message.tool_calls
print(called)
# {
#   "id": 0,
#   "type": "function",
#   "function": {
#     "description": None,
#     "name": "tools",
#     "parameters": {
#       "format": "celsius",
#       "location": "San Francisco, CA",
#       "num_days": 3,
#     },
#   },
# }
```
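
Once a tool call comes back, it still has to be executed on the client side. Below is a minimal sketch of dispatching the returned arguments to local Python functions; the handler bodies are placeholders, and the routing works around the fact that the example output above reports the function name as `"tools"` rather than the tool's own name.

```python
def get_current_weather(location: str, format: str) -> str:
    # Placeholder implementation; a real handler would call a weather API.
    return f"Current weather in {location} ({format})"

def get_n_day_weather_forecast(location: str, format: str, num_days: int) -> str:
    # Placeholder implementation; a real handler would call a weather API.
    return f"{num_days}-day forecast for {location} ({format})"

def run_tool(arguments: dict) -> str:
    # Route on the argument shape, since the reported name is "tools".
    if "num_days" in arguments:
        return get_n_day_weather_forecast(**arguments)
    return get_current_weather(**arguments)

# Arguments taken from the example tool call printed above.
print(run_tool({"format": "celsius", "location": "San Francisco, CA", "num_days": 3}))
# 3-day forecast for San Francisco, CA (celsius)
```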
|
|
@ -23,6 +23,8 @@ from text_generation.types import (
|
|||
Token,
|
||||
BestOfSequence,
|
||||
Grammar,
|
||||
ChatComplete,
|
||||
ChatCompletionChunk,
|
||||
)
|
||||
|
||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||
|
@ -59,6 +61,15 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||
) -> bool:
|
||||
def convert_data(data):
|
||||
data = json.loads(data)
|
||||
if isinstance(data, Dict) and "choices" in data:
|
||||
choices = data["choices"]
|
||||
if (
|
||||
isinstance(choices, List)
|
||||
and len(choices) >= 1
|
||||
and "delta" in choices[0]
|
||||
):
|
||||
return ChatCompletionChunk(**data)
|
||||
return ChatComplete(**data)
|
||||
|
||||
if isinstance(data, Dict):
|
||||
return Response(**data)
|
||||
|
@ -144,6 +155,16 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||
)
|
||||
)
|
||||
|
||||
def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
|
||||
return (
|
||||
response.choices[0].message.content == other.choices[0].message.content
|
||||
)
|
||||
|
||||
def eq_chat_complete_chunk(
|
||||
response: ChatCompletionChunk, other: ChatCompletionChunk
|
||||
) -> bool:
|
||||
return response.choices[0].delta.content == other.choices[0].delta.content
|
||||
|
||||
def eq_response(response: Response, other: Response) -> bool:
|
||||
return response.generated_text == other.generated_text and eq_details(
|
||||
response.details, other.details
|
||||
|
@ -157,6 +178,19 @@ class ResponseComparator(JSONSnapshotExtension):
|
|||
if not isinstance(snapshot_data, List):
|
||||
snapshot_data = [snapshot_data]
|
||||
|
||||
if isinstance(serialized_data[0], ChatComplete):
|
||||
return len(snapshot_data) == len(serialized_data) and all(
|
||||
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
||||
)
|
||||
|
||||
if isinstance(serialized_data[0], ChatCompletionChunk):
|
||||
return len(snapshot_data) == len(serialized_data) and all(
|
||||
[
|
||||
eq_chat_complete_chunk(r, o)
|
||||
for r, o in zip(serialized_data, snapshot_data)
|
||||
]
|
||||
)
|
||||
|
||||
return len(snapshot_data) == len(serialized_data) and all(
|
||||
[eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
|
||||
)
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1708957015,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.2-native",
|
||||
"usage": {
|
||||
"completion_tokens": 100,
|
||||
"prompt_tokens": 60,
|
||||
"total_tokens": 160
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos_token",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1709079417,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.2-native",
|
||||
"usage": {
|
||||
"completion_tokens": 29,
|
||||
"prompt_tokens": 316,
|
||||
"total_tokens": 345
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos_token",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1709079492,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.2-native",
|
||||
"usage": {
|
||||
"completion_tokens": 29,
|
||||
"prompt_tokens": 316,
|
||||
"total_tokens": 345
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos_token",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": null,
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"description": null,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY"
|
||||
}
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function"
|
||||
}
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1709079493,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.2-native",
|
||||
"usage": {
|
||||
"completion_tokens": 21,
|
||||
"prompt_tokens": 187,
|
||||
"total_tokens": 208
|
||||
}
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
{
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"content": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": {
|
||||
"function": {
|
||||
"arguments": "</s>",
|
||||
"name": null
|
||||
},
|
||||
"id": "",
|
||||
"index": 20,
|
||||
"type": "function"
|
||||
}
|
||||
},
|
||||
"finish_reason": "eos_token",
|
||||
"index": 20,
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"created": 1709087088,
|
||||
"id": "",
|
||||
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
"object": "text_completion",
|
||||
"system_fingerprint": "1.4.2-native"
|
||||
}
|
|
@ -0,0 +1,240 @@
|
|||
import pytest
|
||||
import json
|
||||
|
||||
from text_generation.types import GrammarType
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_grammar_tools_handle(launcher):
|
||||
with launcher(
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
|
||||
await flash_llama_grammar_tools_handle.health(300)
|
||||
return flash_llama_grammar_tools_handle.client
|
||||
|
||||
|
||||
# tools to be used in the following tests
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_n_day_weather_forecast",
|
||||
"description": "Get an N-day weather forecast",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
"num_days": {
|
||||
"type": "integer",
|
||||
"description": "The number of days to forecast",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format", "num_days"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_no_tools(
|
||||
flash_llama_grammar_tools, response_snapshot
|
||||
):
|
||||
response = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Youre a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0].message.content
|
||||
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
|
||||
response = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
tools=tools,
|
||||
presence_penalty=-1.1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Youre a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
)
|
||||
assert response.choices[0].message.content == None
|
||||
assert response.choices[0].message.tool_calls == {
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14,
|
||||
},
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
}
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_auto(
|
||||
flash_llama_grammar_tools, response_snapshot
|
||||
):
|
||||
response = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
presence_penalty=-1.1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Youre a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
)
|
||||
assert response.choices[0].message.content == None
|
||||
assert response.choices[0].message.tool_calls == {
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "New York, NY",
|
||||
"num_days": 14,
|
||||
},
|
||||
},
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
}
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_choice(
|
||||
flash_llama_grammar_tools, response_snapshot
|
||||
):
|
||||
response = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
tools=tools,
|
||||
tool_choice="get_current_weather",
|
||||
presence_penalty=-1.1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Youre a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Brooklyn, New York?",
|
||||
},
|
||||
],
|
||||
)
|
||||
assert response.choices[0].message.content == None
|
||||
assert response.choices[0].message.tool_calls == {
|
||||
"id": 0,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"description": None,
|
||||
"name": "tools",
|
||||
"parameters": {"format": "celsius", "location": "New York, NY"},
|
||||
},
|
||||
}
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_stream(
|
||||
flash_llama_grammar_tools, response_snapshot
|
||||
):
|
||||
responses = await flash_llama_grammar_tools.chat(
|
||||
max_tokens=100,
|
||||
seed=1,
|
||||
tools=tools,
|
||||
tool_choice="get_current_weather",
|
||||
presence_penalty=-1.1,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Youre a helpful assistant! Answer the users question best you can.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather like in Paris, France?",
|
||||
},
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
count = 0
|
||||
async for response in responses:
|
||||
count += 1
|
||||
|
||||
assert count == 20
|
||||
assert response == response_snapshot
|
|
@ -812,23 +812,27 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
|
@ -877,28 +881,33 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi again!".to_string(),
|
||||
content: Some("Hi again!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
|
@ -952,23 +961,27 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
|
@ -1006,23 +1019,27 @@ mod tests {
|
|||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
content: Some("Hi!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
content: Some("Hello how can I help?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
content: Some("What is Deep Learning?".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
content: Some("magic!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
|
|
|
@ -358,10 +358,11 @@ impl ChatCompletion {
|
|||
pub(crate) fn new(
|
||||
model: String,
|
||||
system_fingerprint: String,
|
||||
output: String,
|
||||
output: Option<String>,
|
||||
created: u64,
|
||||
details: Details,
|
||||
return_logprobs: bool,
|
||||
tool_calls: Option<ToolCall>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: String::new(),
|
||||
|
@ -375,6 +376,7 @@ impl ChatCompletion {
|
|||
role: "assistant".into(),
|
||||
content: output,
|
||||
name: None,
|
||||
tool_calls,
|
||||
},
|
||||
logprobs: return_logprobs
|
||||
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
|
||||
|
@ -413,15 +415,35 @@ pub(crate) struct ChatCompletionChoice {
|
|||
pub(crate) struct ChatCompletionDelta {
|
||||
#[schema(example = "user")]
|
||||
pub role: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
pub content: String,
|
||||
pub content: Option<String>,
|
||||
// default to None
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<DeltaToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
|
||||
pub(crate) struct DeltaToolCall {
|
||||
pub index: u32,
|
||||
pub id: String,
|
||||
pub r#type: String,
|
||||
pub function: Function,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
|
||||
pub(crate) struct Function {
|
||||
pub name: Option<String>,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
impl ChatCompletionChunk {
|
||||
pub(crate) fn new(
|
||||
model: String,
|
||||
system_fingerprint: String,
|
||||
delta: String,
|
||||
delta: Option<String>,
|
||||
tool_calls: Option<Vec<String>>,
|
||||
created: u64,
|
||||
index: u32,
|
||||
logprobs: Option<ChatCompletionLogprobs>,
|
||||
|
@ -438,6 +460,15 @@ impl ChatCompletionChunk {
|
|||
delta: ChatCompletionDelta {
|
||||
role: "assistant".to_string(),
|
||||
content: delta,
|
||||
tool_calls: tool_calls.map(|tc| DeltaToolCall {
|
||||
index,
|
||||
id: String::new(),
|
||||
r#type: "function".to_string(),
|
||||
function: Function {
|
||||
name: None,
|
||||
arguments: tc[0].to_string(),
|
||||
},
|
||||
}),
|
||||
},
|
||||
logprobs,
|
||||
finish_reason,
|
||||
|
@ -520,6 +551,125 @@ pub(crate) struct ChatRequest {
|
|||
#[serde(default)]
|
||||
#[schema(nullable = true, example = 0.95)]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
|
||||
/// functions the model may generate JSON inputs for.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
|
||||
/// A prompt to be appended before the tools
|
||||
#[serde(default = "default_tool_prompt")]
|
||||
#[schema(
|
||||
nullable = true,
|
||||
example = "\"Based on the conversation, please choose the most appropriate tool to use: \""
|
||||
)]
|
||||
pub tool_prompt: Option<String>,
|
||||
|
||||
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||
#[serde(default)]
|
||||
#[schema(nullable = true, example = "null")]
|
||||
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
|
||||
pub tool_choice: Option<ToolType>,
|
||||
}
|
||||
|
||||
fn default_tool_prompt() -> Option<String> {
|
||||
Some(
|
||||
"\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
|
||||
)
|
||||
}
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
enum ToolType {
|
||||
FunctionName(String),
|
||||
OneOf,
|
||||
}
|
||||
|
||||
/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
|
||||
mod deserialize_tool_choice {
|
||||
use super::*;
|
||||
use serde::de;
|
||||
use serde::Deserializer;
|
||||
use serde_json::Value;
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value = Value::deserialize(deserializer)?;
|
||||
|
||||
match value {
|
||||
Value::String(s) => match s.as_str() {
|
||||
"none" => Ok(None),
|
||||
"auto" => Ok(Some(ToolType::OneOf)),
|
||||
_ => Ok(Some(ToolType::FunctionName(s))),
|
||||
},
|
||||
Value::Object(map) => {
|
||||
if let Some(content) = map
|
||||
.get("function")
|
||||
.and_then(|v| v.get("name"))
|
||||
.and_then(|v| v.as_str())
|
||||
{
|
||||
Ok(Some(ToolType::FunctionName(content.to_string())))
|
||||
} else {
|
||||
Err(de::Error::custom("function key not found in tool choice"))
|
||||
}
|
||||
}
|
||||
Value::Null => Ok(Some(ToolType::OneOf)),
|
||||
_ => Err(de::Error::custom("invalid token format")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, ToSchema)]
|
||||
pub struct Tools {
|
||||
#[serde(flatten)]
|
||||
functions_map: FunctionsMap,
|
||||
properties: Properties,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FunctionsMap {
|
||||
#[serde(rename = "$functions")]
|
||||
functions: std::collections::HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FunctionRef {
|
||||
#[serde(rename = "$ref")]
|
||||
ref_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Properties {
|
||||
#[serde(serialize_with = "serialize_function")]
|
||||
function: Vec<FunctionRef>,
|
||||
}
|
||||
|
||||
fn serialize_function<S>(functions: &Vec<FunctionRef>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
use serde::ser::SerializeStruct;
|
||||
let mut state = serializer.serialize_struct("Function", 1)?;
|
||||
state.serialize_field("anyOf", functions)?;
|
||||
state.end()
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
|
||||
pub(crate) struct FunctionDefinition {
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
pub name: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||
pub(crate) struct Tool {
|
||||
// The type of the tool. Currently, only 'function' is supported.
|
||||
#[schema(example = "function")]
|
||||
pub r#type: String,
|
||||
// Grab the tool as generic JSON for debugging purposes.
|
||||
pub function: FunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
|
@ -530,15 +680,25 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
|||
add_generation_prompt: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
|
||||
pub(crate) struct ToolCall {
|
||||
pub id: u32,
|
||||
pub r#type: String,
|
||||
pub function: FunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
pub(crate) struct Message {
|
||||
#[schema(example = "user")]
|
||||
pub role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "My name is David and I")]
|
||||
pub content: String,
|
||||
pub content: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
#[schema(example = "\"David\"")]
|
||||
pub name: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<ToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||
|
|
|
@ -10,6 +10,7 @@ use crate::{
|
|||
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
|
||||
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
|
||||
};
|
||||
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
|
@ -22,6 +23,8 @@ use futures::stream::StreamExt;
|
|||
use futures::Stream;
|
||||
use futures::TryStreamExt;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
@ -581,7 +584,7 @@ async fn chat_completions(
|
|||
let seed = req.seed;
|
||||
|
||||
// apply chat template to flatten the request into a single input
|
||||
let inputs = match infer.apply_chat_template(req.messages) {
|
||||
let mut inputs = match infer.apply_chat_template(req.messages) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
|
@ -596,6 +599,62 @@ async fn chat_completions(
|
|||
}
|
||||
};
|
||||
|
||||
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
|
||||
let tool_prompt = req.tool_prompt.unwrap_or_default();
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: "Tool choice not found in tool names".to_string(),
|
||||
error_type: "Tool not found".to_string(),
|
||||
}),
|
||||
)
|
||||
})?
|
||||
.clone()]
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
let functions: HashMap<String, Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
(func.name, func.parameters)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
let tools_str = serde_json::to_string(&tools).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
inputs = format!("{inputs}{tool_prompt}{tools_str}");
|
||||
Some(GrammarType::Json(serde_json::json!(tools)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// build the request passing some parameters
|
||||
let generate_request = GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
|
@ -617,7 +676,7 @@ async fn chat_completions(
|
|||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: None,
|
||||
grammar: None,
|
||||
grammar: tool_grammar.clone(),
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -640,11 +699,19 @@ async fn chat_completions(
|
|||
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
|
||||
});
|
||||
|
||||
// replace the content with the tool calls if grammar is present
|
||||
let (content, tool_calls) = if tool_grammar.is_some() {
|
||||
(None, Some(vec![stream_token.token.text]))
|
||||
} else {
|
||||
(Some(stream_token.token.text), None)
|
||||
};
|
||||
|
||||
event
|
||||
.json_data(ChatCompletionChunk::new(
|
||||
model_id.clone(),
|
||||
system_fingerprint.clone(),
|
||||
stream_token.token.text,
|
||||
content,
|
||||
tool_calls,
|
||||
current_time,
|
||||
stream_token.index,
|
||||
logprobs,
|
||||
|
@ -681,14 +748,54 @@ async fn chat_completions(
|
|||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs();
|
||||
|
||||
let (tool_calls, output) = if tool_grammar.is_some() {
|
||||
// gen_text should be valid json
|
||||
let gen_text_value: Value =
|
||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
let tool_call = Some(ToolCall {
|
||||
id: 0,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionDefinition {
|
||||
description: None,
|
||||
name: "tools".to_string(),
|
||||
parameters: gen_text_value.get("function").map_or_else(
|
||||
|| {
|
||||
serde_json::from_str(&generation.generated_text).map_err(|e| {
|
||||
(
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
error: e.to_string(),
|
||||
error_type: "Input validation error".to_string(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
},
|
||||
|f| Ok(f.clone()),
|
||||
)?,
|
||||
},
|
||||
});
|
||||
(tool_call, None)
|
||||
} else {
|
||||
(None, Some(generation.generated_text))
|
||||
};
|
||||
// build the complete response object with the full text
|
||||
let response = ChatCompletion::new(
|
||||
model_id,
|
||||
system_fingerprint,
|
||||
generation.generated_text,
|
||||
output,
|
||||
current_time,
|
||||
generation.details.unwrap(),
|
||||
logprobs,
|
||||
tool_calls,
|
||||
);
|
||||
|
||||
// wrap generation inside a Vec to match api-inference
|
||||
|
|