Support tools (#1587)

This work-in-progress PR begins to add support for tools. Tools rely
on grammar support and still have some unsolved challenges. Opening the
PR for visibility and feedback.
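A rough sketch of the intended usage through the Python client added in this PR (the endpoint URL is illustrative, and the exact surface may still change while this is in progress):

```python
from text_generation import Client

# Hypothetical local TGI endpoint with grammar support enabled.
client = Client(base_url="http://localhost:3000")

# Tools are plain OpenAI-style function schemas.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
            },
        },
    }
]

# `chat` accepts OpenAI-style messages plus the new `tools` / `tool_choice` parameters.
response = client.chat(
    messages=[{"role": "user", "content": "What is the weather like in Brooklyn, New York?"}],
    tools=tools,
    tool_choice="get_current_weather",
    max_tokens=100,
)
print(response.choices[0].message.tool_calls)
```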
drbh 2024-02-28 05:10:27 -05:00 committed by GitHub
parent bf700e7eef
commit 9b6db5f793
14 changed files with 1510 additions and 27 deletions


@ -3,7 +3,7 @@ import requests
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator, Union
from text_generation.types import (
StreamResponse,
@ -11,6 +11,11 @@ from text_generation.types import (
Request,
Parameters,
Grammar,
ChatRequest,
ChatCompletionChunk,
ChatComplete,
Message,
Tool,
)
from text_generation.errors import parse_error
@ -59,6 +64,114 @@ class Client:
self.cookies = cookies
self.timeout = timeout
def chat(
self,
messages: List[Message],
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
stream: bool = False,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
tool_choice: Optional[str] = None,
) -> Union[ChatComplete, Iterator[ChatCompletionChunk]]:
"""
Given a list of messages, generate a response
Args:
messages (`List[Message]`):
List of messages
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
logit_bias (`List[float]`):
Adjust the likelihood of specified tokens
logprobs (`bool`):
Include log probabilities in the response
top_logprobs (`int`):
Include the `n` most likely tokens at each step
max_tokens (`int`):
Maximum number of generated tokens
n (`int`):
Generate `n` completions
presence_penalty (`float`):
The parameter for presence penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
stream (`bool`):
Stream the response
seed (`int`):
Random sampling seed
temperature (`float`):
The value used to modulate the logits distribution.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation
tools (`List[Tool]`):
List of tools to use
tool_choice (`str`):
The tool to use
"""
request = ChatRequest(
model="tgi",
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
top_logprobs=top_logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
stream=stream,
seed=seed,
temperature=temperature,
top_p=top_p,
tools=tools,
tool_choice=tool_choice,
)
if not stream:
resp = requests.post(
f"{self.base_url}/v1/chat/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
)
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
return ChatComplete(**payload)
else:
return self._chat_stream_response(request)
def _chat_stream_response(self, request):
resp = requests.post(
f"{self.base_url}/v1/chat/completions",
json=request.dict(),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
stream=True,
)
# iterate over the SSE stream and yield parsed chunks
for byte_payload in resp.iter_lines():
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = ChatCompletionChunk(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status_code, json_payload)
def generate(
self,
prompt: str,
@ -313,6 +426,113 @@ class AsyncClient:
self.cookies = cookies
self.timeout = ClientTimeout(timeout * 60)
async def chat(
self,
messages: List[Message],
frequency_penalty: Optional[float] = None,
logit_bias: Optional[List[float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
stream: bool = False,
seed: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
tools: Optional[List[Tool]] = None,
tool_choice: Optional[str] = None,
) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
"""
Given a list of messages, generate a response asynchronously
Args:
messages (`List[Message]`):
List of messages
frequency_penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
logit_bias (`List[float]`):
Adjust the likelihood of specified tokens
logprobs (`bool`):
Include log probabilities in the response
top_logprobs (`int`):
Include the `n` most likely tokens at each step
max_tokens (`int`):
Maximum number of generated tokens
n (`int`):
Generate `n` completions
presence_penalty (`float`):
The parameter for presence penalty. 0.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
stream (`bool`):
Stream the response
seed (`int`):
Random sampling seed
temperature (`float`):
The value used to modulate the logits distribution.
top_p (`float`):
If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
higher are kept for generation
tools (`List[Tool]`):
List of tools to use
tool_choice (`str`):
The tool to use
"""
request = ChatRequest(
model="tgi",
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
top_logprobs=top_logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
stream=stream,
seed=seed,
temperature=temperature,
top_p=top_p,
tools=tools,
tool_choice=tool_choice,
)
if not stream:
return await self._chat_single_response(request)
else:
return self._chat_stream_response(request)
async def _chat_single_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/chat/completions", json=request.dict()
) as resp:
payload = await resp.json()
if resp.status != 200:
raise parse_error(resp.status, payload)
return ChatComplete(**payload)
async def _chat_stream_response(self, request):
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(
f"{self.base_url}/v1/chat/completions", json=request.dict()
) as resp:
async for byte_payload in resp.content:
if byte_payload == b"\n":
continue
payload = byte_payload.decode("utf-8")
if payload.startswith("data:"):
json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
try:
response = ChatCompletionChunk(**json_payload)
yield response
except ValidationError:
raise parse_error(resp.status, json_payload)
async def generate(
self,
prompt: str,


@ -1,6 +1,6 @@
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List, Union, Any
from text_generation.errors import ValidationError
@ -19,6 +19,124 @@ class Grammar(BaseModel):
value: Union[str, dict]
class ToolCall(BaseModel):
# Id of the tool call
id: int
# Type of the tool call
type: str
# Function details of the tool call
function: dict
class Message(BaseModel):
# Role of the message sender
role: str
# Content of the message
content: Optional[str]
# Optional name of the message sender
name: Optional[str] = None
# Tool calls associated with the chat completion
tool_calls: Optional[Any] = None
class Tool(BaseModel):
# Type of the tool
type: str
# Function details of the tool
function: dict
class ChatCompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
message: Message
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
# Usage details of the chat completion
usage: Any
class Function(BaseModel):
name: Optional[str]
arguments: str
class ChoiceDeltaToolCall(BaseModel):
index: int
id: str
type: str
function: Function
class ChoiceDelta(BaseModel):
role: str
content: Optional[str]
tool_calls: Optional[ChoiceDeltaToolCall]
class Choice(BaseModel):
index: int
delta: ChoiceDelta
logprobs: Optional[dict] = None
finish_reason: Optional[str] = None
class ChatCompletionChunk(BaseModel):
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[Choice]
class ChatComplete(BaseModel):
# Chat completion details
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[ChatCompletionComplete]
usage: Any
class ChatRequest(BaseModel):
# Model identifier
model: str
# List of messages in the conversation
messages: List[Message]
# Penalty for frequency of new tokens
frequency_penalty: Optional[float] = None
# Bias values for token selection
logit_bias: Optional[List[float]] = None
# Whether to return log probabilities
logprobs: Optional[bool] = None
# Number of most likely tokens to return at each position
top_logprobs: Optional[int] = None
# Maximum number of tokens to generate
max_tokens: Optional[int] = None
# Number of chat completion choices to generate
n: Optional[int] = None
# Penalty for presence of new tokens
presence_penalty: Optional[float] = None
# Flag to indicate streaming response
stream: bool = False
# Random sampling seed
seed: Optional[int] = None
# Sampling temperature
temperature: Optional[float] = None
# Top-p value for nucleus sampling
top_p: Optional[float] = None
# List of tools to be used
tools: Optional[List[Tool]] = None
# Choice of tool to be used
tool_choice: Optional[str] = None
class Parameters(BaseModel):
# Activate logits sampling
do_sample: bool = False


@ -9,6 +9,8 @@
title: Supported Models and Hardware
- local: messages_api
title: Messages API
- local: guidance
title: Guidance
title: Getting started
- sections:
- local: basic_tutorials/consuming_tgi

docs/source/guidance.md Normal file

@ -0,0 +1,419 @@
# Guidance
Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.
These features are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
## Quick Start
Before we jump into the deep end, ensure your system is using TGI version `1.4.3` or later to access all the features we're about to explore in this guide.
If you're not up to date, grab the latest version and let's get started!
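If you are not sure which version a running server is on, one quick check is TGI's `/info` endpoint (a sketch assuming a server at `localhost:3000`, as in the examples below):

```python
import requests

# /info returns server metadata, including the running TGI version.
info = requests.get("http://localhost:3000/info").json()
print(info["version"])  # expect 1.4.3 or later for grammar and tool support
```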
## Table of Contents 📚
### Grammar and Constraints
- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.
### Tools and Functions
- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions.
- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions.
- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
## Grammar and Constraints 🛣️
### The Grammar Parameter
In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the AI. This is a game-changer for those who need precise control over the AI's output.
Using curl, you can make a request to TGI's `generate` endpoint with the grammar parameter. This is the most primitive way to interact with the API, and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability.
```bash
curl localhost:3000/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
    "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park",
    "parameters": {
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": {
                "properties": {
                    "location": {
                        "type": "string"
                    },
                    "activity": {
                        "type": "string"
                    },
                    "animals_seen": {
                        "type": "integer",
                        "minimum": 1,
                        "maximum": 5
                    },
                    "animals": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        }
                    }
                },
                "required": ["location", "activity", "animals_seen", "animals"]
            }
        }
    }
}'
// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"}
```
A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
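Regular-expression grammars go through the same `grammar` field with `"type": "regex"`; here is a small sketch in the same style as the JSON examples (the prompt is illustrative, and the pattern is the IPv4 regex reused in the client example further down):

```python
import requests

data = {
    "inputs": "Whats Googles DNS",
    "parameters": {
        "max_new_tokens": 10,
        "grammar": {
            "type": "regex",
            # constrain the output to an IPv4-shaped string
            "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
        },
    },
}

response = requests.post(
    "http://127.0.0.1:3000/generate",
    headers={"Content-Type": "application/json"},
    json=data,
)
print(response.json()["generated_text"])
```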
### Constrain with Pydantic
Pydantic is a powerful library for data validation and settings management. It's the perfect tool for crafting a specific response format.
Using Pydantic models, we can define the same grammar as in the previous example in a shorter and more readable way.
```python
import requests
from pydantic import BaseModel, conint
from typing import List

class Animals(BaseModel):
    location: str
    activity: str
    animals_seen: conint(ge=1, le=5)  # Constrained integer type
    animals: List[str]

prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park"

data = {
    "inputs": prompt,
    "parameters": {
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": Animals.schema()
        }
    }
}

headers = {
    "Content-Type": "application/json",
}

response = requests.post(
    'http://127.0.0.1:3000/generate',
    headers=headers,
    json=data
)
print(response.json())
# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'}
```
### JSON Schema Integration
If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is similar to the first example but with programmatic control.
```python
import requests

json_schema = {
    "properties": {
        "location": {
            "type": "string"
        },
        "activity": {
            "type": "string"
        },
        "animals_seen": {
            "type": "integer",
            "minimum": 1,
            "maximum": 5
        },
        "animals": {
            "type": "array",
            "items": {
                "type": "string"
            }
        }
    },
    "required": ["location", "activity", "animals_seen", "animals"]
}

data = {
    "inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]",
    "parameters": {
        "max_new_tokens": 200,
        "repetition_penalty": 1.3,
        "grammar": {
            "type": "json",
            "value": json_schema
        }
    }
}

headers = {
    "Content-Type": "application/json",
}

response = requests.post(
    'http://127.0.0.1:3000/generate',
    headers=headers,
    json=data
)
print(response.json())
# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'}
```
### Using the client
TGI provides a client library that makes it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter.
```python
from text_generation import AsyncClient
from text_generation.types import GrammarType

# Define an async function to encapsulate the async operation
async def main():
    client = AsyncClient(base_url="http://localhost:3000")

    # Use 'await' to wait for the async method 'generate' to complete
    response = await client.generate(
        "Whats Googles DNS",
        max_new_tokens=10,
        decoder_input_details=True,
        seed=1,
        grammar={
            "type": GrammarType.Regex,
            "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
        },
    )

    # Once the response is received, you can process it
    print(response.generated_text)

# Ensure the main async function is run in the event loop
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

# 118.8.0.84
```
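If you'd rather not use `asyncio`, the synchronous `Client` should accept the same parameters; a minimal sketch under that assumption:

```python
from text_generation import Client
from text_generation.types import GrammarType

client = Client(base_url="http://localhost:3000")

# Same regex-constrained request as above, without the event loop.
response = client.generate(
    "Whats Googles DNS",
    max_new_tokens=10,
    seed=1,
    grammar={
        "type": GrammarType.Regex,
        "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
    },
)
print(response.generated_text)
```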
## Tools and Functions 🛠️
### The Tools Parameter
In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API.
Tools are a set of user-defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more.
Functions, similar to grammars, are defined as JSON schemas and can be passed as part of the parameters to the Messages API.
```bash
curl localhost:3000/v1/chat/completions \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
    "model": "tgi",
    "messages": [
        {
            "role": "user",
            "content": "What is the weather like in New York?"
        }
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location."
                        }
                    },
                    "required": ["location", "format"]
                }
            }
        }
    ],
    "tool_choice": "get_current_weather"
}'
// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}}
```
<details>
<summary>Tools used in example below</summary>
```python
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["location", "format", "num_days"],
            },
        },
    },
]
```
</details>
### Text Generation Inference Client
TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions.
```python
from text_generation import AsyncClient

# NOTE: tools defined above and removed for brevity

# Define an async function to encapsulate the async operation
async def main():
    client = AsyncClient(base_url="http://localhost:3000")

    # Use 'await' to wait for the async method 'chat' to complete
    response = await client.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        presence_penalty=-1.1,
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )

    # Once the response is received, you can process it
    print(response.choices[0].message.tool_calls)

# Ensure the main async function is run in the event loop
if __name__ == "__main__":
    import asyncio
    asyncio.run(main())

# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}}
```
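Streaming works the same way: pass `stream=True` and iterate over `ChatCompletionChunk` objects, where partial tool calls arrive on `delta.tool_calls.function.arguments` (per the client types added in this PR). A sketch reusing the `tools` list above:

```python
from text_generation import AsyncClient

# NOTE: tools defined above and removed for brevity

async def main():
    client = AsyncClient(base_url="http://localhost:3000")

    responses = await client.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        tool_choice="get_current_weather",
        messages=[
            {"role": "user", "content": "What is the weather like in Paris, France?"},
        ],
        stream=True,
    )

    # Each chunk carries a fragment of the generated tool call.
    async for chunk in responses:
        delta = chunk.choices[0].delta
        if delta.tool_calls is not None:
            print(delta.tool_calls.function.arguments, end="")

if __name__ == "__main__":
    import asyncio
    asyncio.run(main())
```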
### OpenAI integration
TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
However, there are some minor differences in the API; for example, `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API, where `tool_choice="auto"` will choose a tool only if the model thinks it's necessary.
```python
from openai import OpenAI

# Initialize the client, pointing it to one of the available models
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="_",
)

# NOTE: tools defined above and removed for brevity

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {
            "role": "system",
            "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
        },
        {
            "role": "user",
            "content": "What's the weather like the next 3 days in San Francisco, CA?",
        },
    ],
    tools=tools,
    tool_choice="auto",  # tool selected by model
    max_tokens=500,
)

called = chat_completion.choices[0].message.tool_calls
print(called)
# {
#     "id": 0,
#     "type": "function",
#     "function": {
#         "description": None,
#         "name": "tools",
#         "parameters": {
#             "format": "celsius",
#             "location": "San Francisco, CA",
#             "num_days": 3,
#         },
#     },
# }
```
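The OpenAI client can also stream plain (non-tool) chat completions from the same endpoint, which is a handy way to sanity-check the connection before wiring up tools; a minimal sketch:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:3000/v1", api_key="_")

stream = client.chat.completions.create(
    model="tgi",
    messages=[{"role": "user", "content": "What is deep learning?"}],
    max_tokens=100,
    stream=True,
)

# Print tokens as they arrive.
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```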


@ -23,6 +23,8 @@ from text_generation.types import (
Token,
BestOfSequence,
Grammar,
ChatComplete,
ChatCompletionChunk,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@ -59,6 +61,15 @@ class ResponseComparator(JSONSnapshotExtension):
) -> bool:
def convert_data(data):
data = json.loads(data)
if isinstance(data, Dict) and "choices" in data:
choices = data["choices"]
if (
isinstance(choices, List)
and len(choices) >= 1
and "delta" in choices[0]
):
return ChatCompletionChunk(**data)
return ChatComplete(**data)
if isinstance(data, Dict):
return Response(**data)
@ -144,6 +155,16 @@ class ResponseComparator(JSONSnapshotExtension):
)
)
def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
return (
response.choices[0].message.content == other.choices[0].message.content
)
def eq_chat_complete_chunk(
response: ChatCompletionChunk, other: ChatCompletionChunk
) -> bool:
return response.choices[0].delta.content == other.choices[0].delta.content
def eq_response(response: Response, other: Response) -> bool:
return response.generated_text == other.generated_text and eq_details(
response.details, other.details
@ -157,6 +178,19 @@ class ResponseComparator(JSONSnapshotExtension):
if not isinstance(snapshot_data, List):
snapshot_data = [snapshot_data]
if isinstance(serialized_data[0], ChatComplete):
return len(snapshot_data) == len(serialized_data) and all(
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
)
if isinstance(serialized_data[0], ChatCompletionChunk):
return len(snapshot_data) == len(serialized_data) and all(
[
eq_chat_complete_chunk(r, o)
for r, o in zip(serialized_data, snapshot_data)
]
)
return len(snapshot_data) == len(serialized_data) and all(
[eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
)


@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1708957015,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 100,
"prompt_tokens": 60,
"total_tokens": 160
}
}


@ -0,0 +1,38 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
},
"usage": null
}
],
"created": 1709079417,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 29,
"prompt_tokens": 316,
"total_tokens": 345
}
}


@ -0,0 +1,38 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14
}
},
"id": 0,
"type": "function"
}
},
"usage": null
}
],
"created": 1709079492,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 29,
"prompt_tokens": 316,
"total_tokens": 345
}
}


@ -0,0 +1,37 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 0,
"logprobs": null,
"message": {
"content": null,
"name": null,
"role": "assistant",
"tool_calls": {
"function": {
"description": null,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY"
}
},
"id": 0,
"type": "function"
}
},
"usage": null
}
],
"created": 1709079493,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native",
"usage": {
"completion_tokens": 21,
"prompt_tokens": 187,
"total_tokens": 208
}
}


@ -0,0 +1,27 @@
{
"choices": [
{
"delta": {
"content": null,
"role": "assistant",
"tool_calls": {
"function": {
"arguments": "</s>",
"name": null
},
"id": "",
"index": 20,
"type": "function"
}
},
"finish_reason": "eos_token",
"index": 20,
"logprobs": null
}
],
"created": 1709087088,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "1.4.2-native"
}


@ -0,0 +1,240 @@
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
await flash_llama_grammar_tools_handle.health(300)
return flash_llama_grammar_tools_handle.client
# tools to be used in the following tests
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
},
"required": ["location", "format"],
},
},
},
{
"type": "function",
"function": {
"name": "get_n_day_weather_forecast",
"description": "Get an N-day weather forecast",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the users location.",
},
"num_days": {
"type": "integer",
"description": "The number of days to forecast",
},
},
"required": ["location", "format", "num_days"],
},
},
},
]
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_no_tools(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert (
response.choices[0].message.content
== "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == {
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0,
"type": "function",
}
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
tool_choice="auto",
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == {
"function": {
"description": None,
"name": "tools",
"parameters": {
"format": "celsius",
"location": "New York, NY",
"num_days": 14,
},
},
"id": 0,
"type": "function",
}
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
flash_llama_grammar_tools, response_snapshot
):
response = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Brooklyn, New York?",
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.tool_calls == {
"id": 0,
"type": "function",
"function": {
"description": None,
"name": "tools",
"parameters": {"format": "celsius", "location": "New York, NY"},
},
}
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
flash_llama_grammar_tools, response_snapshot
):
responses = await flash_llama_grammar_tools.chat(
max_tokens=100,
seed=1,
tools=tools,
tool_choice="get_current_weather",
presence_penalty=-1.1,
messages=[
{
"role": "system",
"content": "Youre a helpful assistant! Answer the users question best you can.",
},
{
"role": "user",
"content": "What is the weather like in Paris, France?",
},
],
stream=True,
)
count = 0
async for response in responses:
count += 1
assert count == 20
assert response == response_snapshot


@ -812,23 +812,27 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
},
],
bos_token: Some("[BOS]"),
@ -877,28 +881,33 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("Hi again!".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
},
],
bos_token: Some("[BOS]"),
@ -952,23 +961,27 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
},
],
bos_token: Some("[BOS]"),
@ -1006,23 +1019,27 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
content: Some("Hi!".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()),
name: None,
tool_calls: None,
},
Message {
role: "assistant".to_string(),
content: Some("magic!".to_string()),
name: None,
tool_calls: None,
},
],
bos_token: Some("[BOS]"),


@ -358,10 +358,11 @@ impl ChatCompletion {
pub(crate) fn new(
model: String,
system_fingerprint: String,
output: Option<String>,
created: u64,
details: Details,
return_logprobs: bool,
tool_calls: Option<ToolCall>,
) -> Self {
Self {
id: String::new(),
@ -375,6 +376,7 @@ impl ChatCompletion {
role: "assistant".into(), role: "assistant".into(),
content: output, content: output,
name: None, name: None,
tool_calls,
}, },
logprobs: return_logprobs logprobs: return_logprobs
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
@ -413,15 +415,35 @@ pub(crate) struct ChatCompletionChoice {
pub(crate) struct ChatCompletionDelta { pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")] #[schema(example = "user")]
pub role: String, pub role: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")] #[schema(example = "What is Deep Learning?")]
pub content: String, pub content: Option<String>,
// default to None
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<DeltaToolCall>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub(crate) struct DeltaToolCall {
pub index: u32,
pub id: String,
pub r#type: String,
pub function: Function,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub(crate) struct Function {
pub name: Option<String>,
pub arguments: String,
}
#[allow(clippy::too_many_arguments)]
impl ChatCompletionChunk {
pub(crate) fn new(
model: String,
system_fingerprint: String,
delta: Option<String>,
tool_calls: Option<Vec<String>>,
created: u64,
index: u32,
logprobs: Option<ChatCompletionLogprobs>,
@ -438,6 +460,15 @@ impl ChatCompletionChunk {
delta: ChatCompletionDelta {
role: "assistant".to_string(),
content: delta,
tool_calls: tool_calls.map(|tc| DeltaToolCall {
index,
id: String::new(),
r#type: "function".to_string(),
function: Function {
name: None,
arguments: tc[0].to_string(),
},
}),
},
logprobs,
finish_reason,
@ -520,6 +551,125 @@ pub(crate) struct ChatRequest {
#[serde(default)]
#[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>,
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
/// functions the model may generate JSON inputs for.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tools: Option<Vec<Tool>>,
/// A prompt to be appended before the tools
#[serde(default = "default_tool_prompt")]
#[schema(
nullable = true,
example = "\"Based on the conversation, please choose the most appropriate tool to use: \""
)]
pub tool_prompt: Option<String>,
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, example = "null")]
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
pub tool_choice: Option<ToolType>,
}
fn default_tool_prompt() -> Option<String> {
Some(
"\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
)
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
enum ToolType {
FunctionName(String),
OneOf,
}
/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
mod deserialize_tool_choice {
use super::*;
use serde::de;
use serde::Deserializer;
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<ToolType>, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => match s.as_str() {
"none" => Ok(None),
"auto" => Ok(Some(ToolType::OneOf)),
_ => Ok(Some(ToolType::FunctionName(s))),
},
Value::Object(map) => {
if let Some(content) = map
.get("function")
.and_then(|v| v.get("name"))
.and_then(|v| v.as_str())
{
Ok(Some(ToolType::FunctionName(content.to_string())))
} else {
Err(de::Error::custom("function key not found in tool choice"))
}
}
Value::Null => Ok(Some(ToolType::OneOf)),
_ => Err(de::Error::custom("invalid token format")),
}
}
}
#[derive(Debug, Deserialize, Serialize, ToSchema)]
pub struct Tools {
#[serde(flatten)]
functions_map: FunctionsMap,
properties: Properties,
}
#[derive(Debug, Serialize, Deserialize)]
struct FunctionsMap {
#[serde(rename = "$functions")]
functions: std::collections::HashMap<String, serde_json::Value>,
}
#[derive(Debug, Serialize, Deserialize)]
struct FunctionRef {
#[serde(rename = "$ref")]
ref_path: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct Properties {
#[serde(serialize_with = "serialize_function")]
function: Vec<FunctionRef>,
}
fn serialize_function<S>(functions: &Vec<FunctionRef>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("Function", 1)?;
state.serialize_field("anyOf", functions)?;
state.end()
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct FunctionDefinition {
#[serde(default)]
pub description: Option<String>,
pub name: String,
pub parameters: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct Tool {
// The type of the tool. Currently, only 'function' is supported.
#[schema(example = "function")]
pub r#type: String,
// Grab the tool as generic JSON for debugging purposes.
pub function: FunctionDefinition,
}
#[derive(Clone, Serialize, Deserialize)]
@ -530,15 +680,25 @@ pub(crate) struct ChatTemplateInputs<'a> {
add_generation_prompt: bool,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct ToolCall {
pub id: u32,
pub r#type: String,
pub function: FunctionDefinition,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct Message {
#[schema(example = "user")]
pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[schema(example = "My name is David and I")]
pub content: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"David\"")]
pub name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<ToolCall>,
}
#[derive(Clone, Debug, Deserialize, ToSchema)]


@ -10,6 +10,7 @@ use crate::{
HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse,
};
use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
@ -22,6 +23,8 @@ use futures::stream::StreamExt;
use futures::Stream;
use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
@ -581,7 +584,7 @@ async fn chat_completions(
let seed = req.seed;
// apply chat template to flatten the request into a single input
let mut inputs = match infer.apply_chat_template(req.messages) {
Ok(inputs) => inputs,
Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@ -596,6 +599,62 @@
}
};
let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
let tool_prompt = req.tool_prompt.unwrap_or_default();
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.ok_or_else(|| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Tool choice not found in tool names".to_string(),
error_type: "Tool not found".to_string(),
}),
)
})?
.clone()]
}
ToolType::OneOf => req_tools.to_owned(),
};
let functions: HashMap<String, Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
(func.name, func.parameters)
})
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.collect(),
},
};
let tools_str = serde_json::to_string(&tools).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})?;
inputs = format!("{inputs}{tool_prompt}{tools_str}");
Some(GrammarType::Json(serde_json::json!(tools)))
} else {
None
};
// build the request passing some parameters
let generate_request = GenerateRequest {
inputs: inputs.to_string(),
@ -617,7 +676,7 @@
decoder_input_details: !stream,
seed,
top_n_tokens: None,
grammar: tool_grammar.clone(),
},
};
@ -640,11 +699,19 @@
ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
});
// replace the content with the tool calls if grammar is present
let (content, tool_calls) = if tool_grammar.is_some() {
(None, Some(vec![stream_token.token.text]))
} else {
(Some(stream_token.token.text), None)
};
event
.json_data(ChatCompletionChunk::new(
model_id.clone(),
system_fingerprint.clone(),
content,
tool_calls,
current_time,
stream_token.index,
logprobs,
@ -681,14 +748,54 @@ async fn chat_completions(
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs();
let (tool_calls, output) = if tool_grammar.is_some() {
// gen_text should be valid json
let gen_text_value: Value =
serde_json::from_str(&generation.generated_text).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})?;
let tool_call = Some(ToolCall {
id: 0,
r#type: "function".to_string(),
function: FunctionDefinition {
description: None,
name: "tools".to_string(),
parameters: gen_text_value.get("function").map_or_else(
|| {
serde_json::from_str(&generation.generated_text).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: e.to_string(),
error_type: "Input validation error".to_string(),
}),
)
})
},
|f| Ok(f.clone()),
)?,
},
});
(tool_call, None)
} else {
(None, Some(generation.generated_text))
};
// build the complete response object with the full text
let response = ChatCompletion::new(
model_id,
system_fingerprint,
output,
current_time,
generation.details.unwrap(),
logprobs,
tool_calls,
);
// wrap generation inside a Vec to match api-inference