From 9b6db5f79312466ac698c128c8abd4fb3b7b47d3 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 28 Feb 2024 05:10:27 -0500 Subject: [PATCH] Support tools (#1587) This work in progress PR begins to add support for tools. Tools relies on grammar support and still has some unsolved challenges. Opening the PR for visibility and feedback --- clients/python/text_generation/client.py | 222 +++++++++- clients/python/text_generation/types.py | 120 ++++- docs/source/_toctree.yml | 2 + docs/source/guidance.md | 419 ++++++++++++++++++ integration-tests/conftest.py | 34 ++ .../test_flash_llama_grammar_no_tools.json | 26 ++ .../test_flash_llama_grammar_tools.json | 38 ++ .../test_flash_llama_grammar_tools_auto.json | 38 ++ ...test_flash_llama_grammar_tools_choice.json | 37 ++ ...test_flash_llama_grammar_tools_stream.json | 27 ++ integration-tests/models/test_tools_llama.py | 240 ++++++++++ router/src/infer.rs | 51 ++- router/src/lib.rs | 168 ++++++- router/src/server.rs | 115 ++++- 14 files changed, 1510 insertions(+), 27 deletions(-) create mode 100644 docs/source/guidance.md create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json create mode 100644 integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json create mode 100644 integration-tests/models/test_tools_llama.py diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index bbccbf1d..09660de3 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -3,7 +3,7 @@ import requests from aiohttp import ClientSession, ClientTimeout from pydantic import ValidationError -from typing import Dict, Optional, List, AsyncIterator, Iterator +from typing import Dict, Optional, List, AsyncIterator, Iterator, Union from text_generation.types import ( StreamResponse, @@ -11,6 +11,11 @@ from text_generation.types import ( Request, Parameters, Grammar, + ChatRequest, + ChatCompletionChunk, + ChatComplete, + Message, + Tool, ) from text_generation.errors import parse_error @@ -59,6 +64,114 @@ class Client: self.cookies = cookies self.timeout = timeout + def chat( + self, + messages: List[Message], + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + stream: bool = False, + seed: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Tool]] = None, + tool_choice: Optional[str] = None, + ): + """ + Given a list of messages, generate a response asynchronously + + Args: + messages (`List[Message]`): + List of messages + frequency_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + logit_bias (`List[float]`): + Adjust the likelihood of specified tokens + logprobs (`bool`): + Include log probabilities in the response + top_logprobs (`int`): + Include the `n` most likely tokens at each step + max_tokens (`int`): + Maximum number of generated tokens + n (`int`): + Generate `n` completions + presence_penalty (`float`): + The parameter for presence penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + stream (`bool`): + Stream the response + seed (`int`): + Random sampling seed + temperature (`float`): + The value used to module the logits distribution. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation + tools (`List[Tool]`): + List of tools to use + tool_choice (`str`): + The tool to use + + """ + request = ChatRequest( + model="tgi", + messages=messages, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + top_logprobs=top_logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + stream=stream, + seed=seed, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_choice=tool_choice, + ) + if not stream: + resp = requests.post( + f"{self.base_url}/v1/chat/completions", + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, payload) + return ChatComplete(**payload) + else: + return self._chat_stream_response(request) + + def _chat_stream_response(self, request): + resp = requests.post( + f"{self.base_url}/v1/chat/completions", + json=request.dict(), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + stream=True, + ) + # iterate and print stream + for byte_payload in resp.iter_lines(): + if byte_payload == b"\n": + continue + payload = byte_payload.decode("utf-8") + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + try: + response = ChatCompletionChunk(**json_payload) + yield response + except ValidationError: + raise parse_error(resp.status, json_payload) + def generate( self, prompt: str, @@ -313,6 +426,113 @@ class AsyncClient: self.cookies = cookies self.timeout = ClientTimeout(timeout * 60) + async def chat( + self, + messages: List[Message], + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + stream: bool = False, + seed: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + tools: Optional[List[Tool]] = None, + tool_choice: Optional[str] = None, + ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]: + """ + Given a list of messages, generate a response asynchronously + + Args: + messages (`List[Message]`): + List of messages + frequency_penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + logit_bias (`List[float]`): + Adjust the likelihood of specified tokens + logprobs (`bool`): + Include log probabilities in the response + top_logprobs (`int`): + Include the `n` most likely tokens at each step + max_tokens (`int`): + Maximum number of generated tokens + n (`int`): + Generate `n` completions + presence_penalty (`float`): + The parameter for presence penalty. 0.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + stream (`bool`): + Stream the response + seed (`int`): + Random sampling seed + temperature (`float`): + The value used to module the logits distribution. + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation + tools (`List[Tool]`): + List of tools to use + tool_choice (`str`): + The tool to use + + """ + request = ChatRequest( + model="tgi", + messages=messages, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + top_logprobs=top_logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + stream=stream, + seed=seed, + temperature=temperature, + top_p=top_p, + tools=tools, + tool_choice=tool_choice, + ) + if not stream: + return await self._chat_single_response(request) + else: + return self._chat_stream_response(request) + + async def _chat_single_response(self, request): + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post( + f"{self.base_url}/v1/chat/completions", json=request.dict() + ) as resp: + payload = await resp.json() + if resp.status != 200: + raise parse_error(resp.status, payload) + return ChatComplete(**payload) + + async def _chat_stream_response(self, request): + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post( + f"{self.base_url}/v1/chat/completions", json=request.dict() + ) as resp: + async for byte_payload in resp.content: + if byte_payload == b"\n": + continue + payload = byte_payload.decode("utf-8") + if payload.startswith("data:"): + json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + try: + response = ChatCompletionChunk(**json_payload) + yield response + except ValidationError: + raise parse_error(resp.status, json_payload) + async def generate( self, prompt: str, diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 911114ee..4a308cef 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -1,6 +1,6 @@ from enum import Enum from pydantic import BaseModel, validator -from typing import Optional, List, Union +from typing import Optional, List, Union, Any from text_generation.errors import ValidationError @@ -19,6 +19,124 @@ class Grammar(BaseModel): value: Union[str, dict] +class ToolCall(BaseModel): + # Id of the tool call + id: int + # Type of the tool call + type: str + # Function details of the tool call + function: dict + + +class Message(BaseModel): + # Role of the message sender + role: str + # Content of the message + content: Optional[str] + # Optional name of the message sender + name: Optional[str] = None + # Tool calls associated with the chat completion + tool_calls: Optional[Any] = None + + +class Tool(BaseModel): + # Type of the tool + type: str + # Function details of the tool + function: dict + + +class ChatCompletionComplete(BaseModel): + # Index of the chat completion + index: int + # Message associated with the chat completion + message: Message + # Log probabilities for the chat completion + logprobs: Optional[Any] + # Reason for completion + finish_reason: str + # Usage details of the chat completion + usage: Any + + +class Function(BaseModel): + name: Optional[str] + arguments: str + + +class ChoiceDeltaToolCall(BaseModel): + index: int + id: str + type: str + function: Function + + +class ChoiceDelta(BaseModel): + role: str + content: Optional[str] + tool_calls: Optional[ChoiceDeltaToolCall] + + +class Choice(BaseModel): + index: int + delta: ChoiceDelta + logprobs: Optional[dict] = None + finish_reason: Optional[str] = None + + +class ChatCompletionChunk(BaseModel): + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[Choice] + + +class ChatComplete(BaseModel): + # Chat completion details + id: str + object: str + created: int + model: str + system_fingerprint: str + choices: List[ChatCompletionComplete] + usage: Any + + +class ChatRequest(BaseModel): + # Model identifier + model: str + # List of messages in the conversation + messages: List[Message] + # Penalty for frequency of new tokens + frequency_penalty: Optional[float] = None + # Bias values for token selection + logit_bias: Optional[List[float]] = None + # Whether to return log probabilities + logprobs: Optional[bool] = None + # Number of most likely tokens to return at each position + top_logprobs: Optional[int] = None + # Maximum number of tokens to generate + max_tokens: Optional[int] = None + # Number of chat completion choices to generate + n: Optional[int] = None + # Penalty for presence of new tokens + presence_penalty: Optional[float] = None + # Flag to indicate streaming response + stream: bool = False + # Random sampling seed + seed: Optional[int] = None + # Sampling temperature + temperature: Optional[float] = None + # Top-p value for nucleus sampling + top_p: Optional[float] = None + # List of tools to be used + tools: Optional[List[Tool]] = None + # Choice of tool to be used + tool_choice: Optional[str] = None + + class Parameters(BaseModel): # Activate logits sampling do_sample: bool = False diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index d57a594d..964a743a 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -9,6 +9,8 @@ title: Supported Models and Hardware - local: messages_api title: Messages API + - local: guidance + title: Guidance title: Getting started - sections: - local: basic_tutorials/consuming_tgi diff --git a/docs/source/guidance.md b/docs/source/guidance.md new file mode 100644 index 00000000..8b9ba094 --- /dev/null +++ b/docs/source/guidance.md @@ -0,0 +1,419 @@ +# Guidance + +Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developer guide LLM responses to fit their needs. + +These feature are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them! + +## Quick Start + +Before we jump into the deep end, ensure your system is using TGI version `1.4.3` or later to access all the features we're about to explore in this guide. + +If you're not up to date, grab the latest version and let's get started! + +## Table of Contents 📚 + +### Grammar and Constraints + +- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision. +- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models. +- [JSON Schema Integration](#json-schema-integration): Fine grain control over your requests via JSON schema. +- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses. + +### Tools and Functions + +- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions. +- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions. +- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. + +## Grammar and Constraints 🛣️ + +### The Grammar Parameter + +In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the AI. This is a game-changer for those who need precise control over the AI's output. + +Using curl, you can make a request to TGI's Messages API with the grammar parameter. This is the most primitive way to interact with the API and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability. + +```json +curl localhost:3000/generate \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{ + "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park", + "parameters": { + "repetition_penalty": 1.3, + "grammar": { + "type": "json", + "value": { + "properties": { + "location": { + "type": "string" + }, + "activity": { + "type": "string" + }, + "animals_seen": { + "type": "integer", + "minimum": 1, + "maximum": 5 + }, + "animals": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": ["location", "activity", "animals_seen", "animals"] + } + } + } +}' +// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"} + +``` + +A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar. + +> Note: A grammar must compile to a intermediate representation to constrain the output. Grammar compliation is a computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster. + +### Constrain with Pydantic + +Pydantic is a powerful library for data validation and settings management. It's the perfect tool for crafting the a specific response format. + +Using Pydantic models we can define a similar grammar as the previous example in a shorter and more readable way. + +```python +import requests +from pydantic import BaseModel, conint +from typing import List + +class Animals(BaseModel): + location: str + activity: str + animals_seen: conint(ge=1, le=5) # Constrained integer type + animals: List[str] + +prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park" + +data = { + "inputs": prompt, + "parameters": { + "repetition_penalty": 1.3, + "grammar": { + "type": "json", + "value": Animals.schema() + } + } +} + +headers = { + "Content-Type": "application/json", +} + +response = requests.post( + 'http://127.0.0.1:3000/generate', + headers=headers, + json=data +) +print(response.json()) +# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'} + +``` + +### JSON Schema Integration + +If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is simliar to the first example but with programmatic control. + +```python +import requests + +json_schema = { + "properties": { + "location": { + "type": "string" + }, + "activity": { + "type": "string" + }, + "animals_seen": { + "type": "integer", + "minimum": 1, + "maximum": 5 + }, + "animals": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "required": ["location", "activity", "animals_seen", "animals"] +} + +data = { + "inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]", + "parameters": { + "max_new_tokens": 200, + "repetition_penalty": 1.3, + "grammar": { + "type": "json", + "value": json_schema + } + } +} + +headers = { + "Content-Type": "application/json", +} + +response = requests.post( + 'http://127.0.0.1:3000/generate', + headers=headers, + json=data +) +print(response.json()) +# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'} + +``` + +### Using the client + +TGI provides a client library to that make it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter. + +```python +from text_generation import AsyncClient +from text_generation.types import GrammarType + +# NOTE: tools defined above and removed for brevity + +# Define an async function to encapsulate the async operation +async def main(): + client = AsyncClient(base_url="http://localhost:3000") + + # Use 'await' to wait for the async method 'chat' to complete + response = await client.generate( + "Whats Googles DNS", + max_new_tokens=10, + decoder_input_details=True, + seed=1, + grammar={ + "type": GrammarType.Regex, + "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)", + }, + ) + + # Once the response is received, you can process it + print(response.generated_text) + +# Ensure the main async function is run in the event loop +if __name__ == "__main__": + import asyncio + asyncio.run(main()) + +# 118.8.0.84 + +``` + +## Tools and Functions 🛠️ + +### The Tools Parameter + +In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API. + +Tools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more. + +Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API. + +```json +curl localhost:3000/v1/chat/completions \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "tgi", + "messages": [ + { + "role": "user", + "content": "What is the weather like in New York?" + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location." + } + }, + "required": ["location", "format"] + } + } + } + ], + "tool_choice": "get_current_weather" +}' +// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}} +``` + +
+ Tools used in example below + + ```python + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_n_day_weather_forecast", + "description": "Get an N-day weather forecast", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + "num_days": { + "type": "integer", + "description": "The number of days to forecast", + }, + }, + "required": ["location", "format", "num_days"], + }, + }, + } + ] + ``` + +
+ +### Text Generation Inference Client + +TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions. + +```python +from text_generation import AsyncClient + +# NOTE: tools defined above and removed for brevity + +# Define an async function to encapsulate the async operation +async def main(): + client = AsyncClient(base_url="http://localhost:3000") + + # Use 'await' to wait for the async method 'chat' to complete + response = await client.chat( + max_tokens=100, + seed=1, + tools=tools, + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "You're a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + + # Once the response is received, you can process it + print(response.choices[0].message.tool_calls) + +# Ensure the main async function is run in the event loop +if __name__ == "__main__": + import asyncio + asyncio.run(main()) + +# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}} + +``` + +### OpenAI integration + +TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions. + +However there are some minor differences in the API, for example `tool_choice="auto"` will ALWAYS choose the tool for you. This is different from OpenAI's API where `tool_choice="auto"` will choose a tool if the model thinks it's necessary. + +```python +from openai import OpenAI + +# Initialize the client, pointing it to one of the available models +client = OpenAI( + base_url="http://localhost:3000/v1", + api_key="_", +) + +# NOTE: tools defined above and removed for brevity + +chat_completion = client.chat.completions.create( + model="tgi", + messages=[ + { + "role": "system", + "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.", + }, + { + "role": "user", + "content": "What's the weather like the next 3 days in San Francisco, CA?", + }, + ], + tools=tools, + tool_choice="auto", # tool selected by model + max_tokens=500, +) + + +called = chat_completion.choices[0].message.tool_calls +print(called) +# { +# "id": 0, +# "type": "function", +# "function": { +# "description": None, +# "name": "tools", +# "parameters": { +# "format": "celsius", +# "location": "San Francisco, CA", +# "num_days": 3, +# }, +# }, +# } +``` diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index e11c7cf9..96cf43ad 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -23,6 +23,8 @@ from text_generation.types import ( Token, BestOfSequence, Grammar, + ChatComplete, + ChatCompletionChunk, ) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) @@ -59,6 +61,15 @@ class ResponseComparator(JSONSnapshotExtension): ) -> bool: def convert_data(data): data = json.loads(data) + if isinstance(data, Dict) and "choices" in data: + choices = data["choices"] + if ( + isinstance(choices, List) + and len(choices) >= 1 + and "delta" in choices[0] + ): + return ChatCompletionChunk(**data) + return ChatComplete(**data) if isinstance(data, Dict): return Response(**data) @@ -144,6 +155,16 @@ class ResponseComparator(JSONSnapshotExtension): ) ) + def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool: + return ( + response.choices[0].message.content == other.choices[0].message.content + ) + + def eq_chat_complete_chunk( + response: ChatCompletionChunk, other: ChatCompletionChunk + ) -> bool: + return response.choices[0].delta.content == other.choices[0].delta.content + def eq_response(response: Response, other: Response) -> bool: return response.generated_text == other.generated_text and eq_details( response.details, other.details @@ -157,6 +178,19 @@ class ResponseComparator(JSONSnapshotExtension): if not isinstance(snapshot_data, List): snapshot_data = [snapshot_data] + if isinstance(serialized_data[0], ChatComplete): + return len(snapshot_data) == len(serialized_data) and all( + [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)] + ) + + if isinstance(serialized_data[0], ChatCompletionChunk): + return len(snapshot_data) == len(serialized_data) and all( + [ + eq_chat_complete_chunk(r, o) + for r, o in zip(serialized_data, snapshot_data) + ] + ) + return len(snapshot_data) == len(serialized_data) and all( [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)] ) diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json new file mode 100644 index 00000000..3c4b4aea --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json @@ -0,0 +1,26 @@ +{ + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally", + "name": null, + "role": "assistant", + "tool_calls": null + }, + "usage": null + } + ], + "created": 1708957015, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native", + "usage": { + "completion_tokens": 100, + "prompt_tokens": 60, + "total_tokens": 160 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json new file mode 100644 index 00000000..9b9e33c6 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json @@ -0,0 +1,38 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": { + "function": { + "description": null, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY", + "num_days": 14 + } + }, + "id": 0, + "type": "function" + } + }, + "usage": null + } + ], + "created": 1709079417, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native", + "usage": { + "completion_tokens": 29, + "prompt_tokens": 316, + "total_tokens": 345 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json new file mode 100644 index 00000000..de32c970 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json @@ -0,0 +1,38 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": { + "function": { + "description": null, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY", + "num_days": 14 + } + }, + "id": 0, + "type": "function" + } + }, + "usage": null + } + ], + "created": 1709079492, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native", + "usage": { + "completion_tokens": 29, + "prompt_tokens": 316, + "total_tokens": 345 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json new file mode 100644 index 00000000..3551e205 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json @@ -0,0 +1,37 @@ +{ + "choices": [ + { + "finish_reason": "eos_token", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "name": null, + "role": "assistant", + "tool_calls": { + "function": { + "description": null, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY" + } + }, + "id": 0, + "type": "function" + } + }, + "usage": null + } + ], + "created": 1709079493, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native", + "usage": { + "completion_tokens": 21, + "prompt_tokens": 187, + "total_tokens": 208 + } +} diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json new file mode 100644 index 00000000..c367cc6f --- /dev/null +++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json @@ -0,0 +1,27 @@ +{ + "choices": [ + { + "delta": { + "content": null, + "role": "assistant", + "tool_calls": { + "function": { + "arguments": "", + "name": null + }, + "id": "", + "index": 20, + "type": "function" + } + }, + "finish_reason": "eos_token", + "index": 20, + "logprobs": null + } + ], + "created": 1709087088, + "id": "", + "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "object": "text_completion", + "system_fingerprint": "1.4.2-native" +} diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py new file mode 100644 index 00000000..38570c38 --- /dev/null +++ b/integration-tests/models/test_tools_llama.py @@ -0,0 +1,240 @@ +import pytest +import json + +from text_generation.types import GrammarType + + +@pytest.fixture(scope="module") +def flash_llama_grammar_tools_handle(launcher): + with launcher( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle): + await flash_llama_grammar_tools_handle.health(300) + return flash_llama_grammar_tools_handle.client + + +# tools to be used in the following tests +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + }, + "required": ["location", "format"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_n_day_weather_forecast", + "description": "Get an N-day weather forecast", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature unit to use. Infer this from the users location.", + }, + "num_days": { + "type": "integer", + "description": "The number of days to forecast", + }, + }, + "required": ["location", "format", "num_days"], + }, + }, + }, +] + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_no_tools( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + + assert ( + response.choices[0].message.content + == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally" + ) + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == { + "function": { + "description": None, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY", + "num_days": 14, + }, + }, + "id": 0, + "type": "function", + } + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_auto( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="auto", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == { + "function": { + "description": None, + "name": "tools", + "parameters": { + "format": "celsius", + "location": "New York, NY", + "num_days": 14, + }, + }, + "id": 0, + "type": "function", + } + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_choice( + flash_llama_grammar_tools, response_snapshot +): + response = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="get_current_weather", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Brooklyn, New York?", + }, + ], + ) + assert response.choices[0].message.content == None + assert response.choices[0].message.tool_calls == { + "id": 0, + "type": "function", + "function": { + "description": None, + "name": "tools", + "parameters": {"format": "celsius", "location": "New York, NY"}, + }, + } + assert response == response_snapshot + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_grammar_tools_stream( + flash_llama_grammar_tools, response_snapshot +): + responses = await flash_llama_grammar_tools.chat( + max_tokens=100, + seed=1, + tools=tools, + tool_choice="get_current_weather", + presence_penalty=-1.1, + messages=[ + { + "role": "system", + "content": "Youre a helpful assistant! Answer the users question best you can.", + }, + { + "role": "user", + "content": "What is the weather like in Paris, France?", + }, + ], + stream=True, + ) + + count = 0 + async for response in responses: + count += 1 + + assert count == 20 + assert response == response_snapshot diff --git a/router/src/infer.rs b/router/src/infer.rs index 472b7d66..42405327 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -812,23 +812,27 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: "Hi!".to_string(), + content: Some("Hi!".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), + content: Some("Hello how can I help?".to_string()), name: None, + tool_calls: None, }, Message { role: "user".to_string(), - content: "What is Deep Learning?".to_string(), + content: Some("What is Deep Learning?".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "magic!".to_string(), + content: Some("magic!".to_string()), name: None, + tool_calls: None, }, ], bos_token: Some("[BOS]"), @@ -877,28 +881,33 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: "Hi!".to_string(), + content: Some("Hi!".to_string()), name: None, + tool_calls: None, }, Message { role: "user".to_string(), - content: "Hi again!".to_string(), + content: Some("Hi again!".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), + content: Some("Hello how can I help?".to_string()), name: None, + tool_calls: None, }, Message { role: "user".to_string(), - content: "What is Deep Learning?".to_string(), + content: Some("What is Deep Learning?".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "magic!".to_string(), + content: Some("magic!".to_string()), name: None, + tool_calls: None, }, ], bos_token: Some("[BOS]"), @@ -952,23 +961,27 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: "Hi!".to_string(), + content: Some("Hi!".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), + content: Some("Hello how can I help?".to_string()), name: None, + tool_calls: None, }, Message { role: "user".to_string(), - content: "What is Deep Learning?".to_string(), + content: Some("What is Deep Learning?".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "magic!".to_string(), + content: Some("magic!".to_string()), name: None, + tool_calls: None, }, ], bos_token: Some("[BOS]"), @@ -1006,23 +1019,27 @@ mod tests { messages: vec![ Message { role: "user".to_string(), - content: "Hi!".to_string(), + content: Some("Hi!".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "Hello how can I help?".to_string(), + content: Some("Hello how can I help?".to_string()), name: None, + tool_calls: None, }, Message { role: "user".to_string(), - content: "What is Deep Learning?".to_string(), + content: Some("What is Deep Learning?".to_string()), name: None, + tool_calls: None, }, Message { role: "assistant".to_string(), - content: "magic!".to_string(), + content: Some("magic!".to_string()), name: None, + tool_calls: None, }, ], bos_token: Some("[BOS]"), diff --git a/router/src/lib.rs b/router/src/lib.rs index 1c06eb8a..d89bacb5 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -358,10 +358,11 @@ impl ChatCompletion { pub(crate) fn new( model: String, system_fingerprint: String, - output: String, + output: Option, created: u64, details: Details, return_logprobs: bool, + tool_calls: Option, ) -> Self { Self { id: String::new(), @@ -375,6 +376,7 @@ impl ChatCompletion { role: "assistant".into(), content: output, name: None, + tool_calls, }, logprobs: return_logprobs .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), @@ -413,15 +415,35 @@ pub(crate) struct ChatCompletionChoice { pub(crate) struct ChatCompletionDelta { #[schema(example = "user")] pub role: String, + #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "What is Deep Learning?")] - pub content: String, + pub content: Option, + // default to None + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_calls: Option, } +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +pub(crate) struct DeltaToolCall { + pub index: u32, + pub id: String, + pub r#type: String, + pub function: Function, +} + +#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] +pub(crate) struct Function { + pub name: Option, + pub arguments: String, +} + +#[allow(clippy::too_many_arguments)] impl ChatCompletionChunk { pub(crate) fn new( model: String, system_fingerprint: String, - delta: String, + delta: Option, + tool_calls: Option>, created: u64, index: u32, logprobs: Option, @@ -438,6 +460,15 @@ impl ChatCompletionChunk { delta: ChatCompletionDelta { role: "assistant".to_string(), content: delta, + tool_calls: tool_calls.map(|tc| DeltaToolCall { + index, + id: String::new(), + r#type: "function".to_string(), + function: Function { + name: None, + arguments: tc[0].to_string(), + }, + }), }, logprobs, finish_reason, @@ -520,6 +551,125 @@ pub(crate) struct ChatRequest { #[serde(default)] #[schema(nullable = true, example = 0.95)] pub top_p: Option, + + /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of + /// functions the model may generate JSON inputs for. + #[serde(default)] + #[schema(nullable = true, example = "null")] + pub tools: Option>, + + /// A prompt to be appended before the tools + #[serde(default = "default_tool_prompt")] + #[schema( + nullable = true, + example = "\"Based on the conversation, please choose the most appropriate tool to use: \"" + )] + pub tool_prompt: Option, + + /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. + #[serde(default)] + #[schema(nullable = true, example = "null")] + #[serde(deserialize_with = "deserialize_tool_choice::deserialize")] + pub tool_choice: Option, +} + +fn default_tool_prompt() -> Option { + Some( + "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(), + ) +} +#[derive(Clone, Deserialize, ToSchema, Serialize)] +enum ToolType { + FunctionName(String), + OneOf, +} + +/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None) +mod deserialize_tool_choice { + use super::*; + use serde::de; + use serde::Deserializer; + use serde_json::Value; + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + match value { + Value::String(s) => match s.as_str() { + "none" => Ok(None), + "auto" => Ok(Some(ToolType::OneOf)), + _ => Ok(Some(ToolType::FunctionName(s))), + }, + Value::Object(map) => { + if let Some(content) = map + .get("function") + .and_then(|v| v.get("name")) + .and_then(|v| v.as_str()) + { + Ok(Some(ToolType::FunctionName(content.to_string()))) + } else { + Err(de::Error::custom("function key not found in tool choice")) + } + } + Value::Null => Ok(Some(ToolType::OneOf)), + _ => Err(de::Error::custom("invalid token format")), + } + } +} + +#[derive(Debug, Deserialize, Serialize, ToSchema)] +pub struct Tools { + #[serde(flatten)] + functions_map: FunctionsMap, + properties: Properties, +} + +#[derive(Debug, Serialize, Deserialize)] +struct FunctionsMap { + #[serde(rename = "$functions")] + functions: std::collections::HashMap, +} + +#[derive(Debug, Serialize, Deserialize)] +struct FunctionRef { + #[serde(rename = "$ref")] + ref_path: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Properties { + #[serde(serialize_with = "serialize_function")] + function: Vec, +} + +fn serialize_function(functions: &Vec, serializer: S) -> Result +where + S: serde::Serializer, +{ + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("Function", 1)?; + state.serialize_field("anyOf", functions)?; + state.end() +} + +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)] +pub(crate) struct FunctionDefinition { + #[serde(default)] + pub description: Option, + pub name: String, + pub parameters: serde_json::Value, +} + +#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] +pub(crate) struct Tool { + // The type of the tool. Currently, only 'function' is supported. + #[schema(example = "function")] + pub r#type: String, + // Grab the tool as generic JSON for debugging purposes. + pub function: FunctionDefinition, } #[derive(Clone, Serialize, Deserialize)] @@ -530,15 +680,25 @@ pub(crate) struct ChatTemplateInputs<'a> { add_generation_prompt: bool, } +#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] +pub(crate) struct ToolCall { + pub id: u32, + pub r#type: String, + pub function: FunctionDefinition, +} + #[derive(Clone, Deserialize, ToSchema, Serialize)] pub(crate) struct Message { #[schema(example = "user")] pub role: String, + #[serde(skip_serializing_if = "Option::is_none")] #[schema(example = "My name is David and I")] - pub content: String, + pub content: Option, #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "\"David\"")] pub name: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_calls: Option, } #[derive(Clone, Debug, Deserialize, ToSchema)] diff --git a/router/src/server.rs b/router/src/server.rs index 9fdd66cc..2efa9284 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -10,6 +10,7 @@ use crate::{ HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage, Validation, VertexRequest, VertexResponse, }; +use crate::{FunctionDefinition, FunctionRef, FunctionsMap, Properties, ToolCall, ToolType, Tools}; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; @@ -22,6 +23,8 @@ use futures::stream::StreamExt; use futures::Stream; use futures::TryStreamExt; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; +use serde_json::Value; +use std::collections::HashMap; use std::convert::Infallible; use std::net::SocketAddr; use std::sync::atomic::AtomicBool; @@ -581,7 +584,7 @@ async fn chat_completions( let seed = req.seed; // apply chat template to flatten the request into a single input - let inputs = match infer.apply_chat_template(req.messages) { + let mut inputs = match infer.apply_chat_template(req.messages) { Ok(inputs) => inputs, Err(err) => { metrics::increment_counter!("tgi_request_failure", "err" => "validation"); @@ -596,6 +599,62 @@ async fn chat_completions( } }; + let tool_grammar = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) { + let tool_prompt = req.tool_prompt.unwrap_or_default(); + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![req_tools + .iter() + .find(|tool| tool.function.name == *name) + .ok_or_else(|| { + ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: "Tool choice not found in tool names".to_string(), + error_type: "Tool not found".to_string(), + }), + ) + })? + .clone()] + } + ToolType::OneOf => req_tools.to_owned(), + }; + + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); + (func.name, func.parameters) + }) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .collect(), + }, + }; + + let tools_str = serde_json::to_string(&tools).map_err(|e| { + ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: e.to_string(), + error_type: "Input validation error".to_string(), + }), + ) + })?; + inputs = format!("{inputs}{tool_prompt}{tools_str}"); + Some(GrammarType::Json(serde_json::json!(tools))) + } else { + None + }; + // build the request passing some parameters let generate_request = GenerateRequest { inputs: inputs.to_string(), @@ -617,7 +676,7 @@ async fn chat_completions( decoder_input_details: !stream, seed, top_n_tokens: None, - grammar: None, + grammar: tool_grammar.clone(), }, }; @@ -640,11 +699,19 @@ async fn chat_completions( ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens)) }); + // replace the content with the tool calls if grammar is present + let (content, tool_calls) = if tool_grammar.is_some() { + (None, Some(vec![stream_token.token.text])) + } else { + (Some(stream_token.token.text), None) + }; + event .json_data(ChatCompletionChunk::new( model_id.clone(), system_fingerprint.clone(), - stream_token.token.text, + content, + tool_calls, current_time, stream_token.index, logprobs, @@ -681,14 +748,54 @@ async fn chat_completions( .unwrap_or_else(|_| std::time::Duration::from_secs(0)) .as_secs(); + let (tool_calls, output) = if tool_grammar.is_some() { + // gen_text should be valid json + let gen_text_value: Value = + serde_json::from_str(&generation.generated_text).map_err(|e| { + ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: e.to_string(), + error_type: "Input validation error".to_string(), + }), + ) + })?; + + let tool_call = Some(ToolCall { + id: 0, + r#type: "function".to_string(), + function: FunctionDefinition { + description: None, + name: "tools".to_string(), + parameters: gen_text_value.get("function").map_or_else( + || { + serde_json::from_str(&generation.generated_text).map_err(|e| { + ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(ErrorResponse { + error: e.to_string(), + error_type: "Input validation error".to_string(), + }), + ) + }) + }, + |f| Ok(f.clone()), + )?, + }, + }); + (tool_call, None) + } else { + (None, Some(generation.generated_text)) + }; // build the complete response object with the full text let response = ChatCompletion::new( model_id, system_fingerprint, - generation.generated_text, + output, current_time, generation.details.unwrap(), logprobs, + tool_calls, ); // wrap generation inside a Vec to match api-inference