diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py
index bbccbf1d..09660de3 100644
--- a/clients/python/text_generation/client.py
+++ b/clients/python/text_generation/client.py
@@ -3,7 +3,7 @@ import requests
from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
-from typing import Dict, Optional, List, AsyncIterator, Iterator
+from typing import Dict, Optional, List, AsyncIterator, Iterator, Union
from text_generation.types import (
StreamResponse,
@@ -11,6 +11,11 @@ from text_generation.types import (
Request,
Parameters,
Grammar,
+ ChatRequest,
+ ChatCompletionChunk,
+ ChatComplete,
+ Message,
+ Tool,
)
from text_generation.errors import parse_error
@@ -59,6 +64,114 @@ class Client:
self.cookies = cookies
self.timeout = timeout
+ def chat(
+ self,
+ messages: List[Message],
+ frequency_penalty: Optional[float] = None,
+ logit_bias: Optional[List[float]] = None,
+ logprobs: Optional[bool] = None,
+ top_logprobs: Optional[int] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[float] = None,
+ stream: bool = False,
+ seed: Optional[int] = None,
+ temperature: Optional[float] = None,
+ top_p: Optional[float] = None,
+ tools: Optional[List[Tool]] = None,
+ tool_choice: Optional[str] = None,
+ ) -> Union[ChatComplete, Iterator[ChatCompletionChunk]]:
+ """
+ Given a list of messages, generate a response synchronously
+
+ Args:
+ messages (`List[Message]`):
+ List of messages
+ frequency_penalty (`float`):
+ The parameter for frequency penalty. 0.0 means no penalty. See [this
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+ logit_bias (`List[float]`):
+ Adjust the likelihood of specified tokens
+ logprobs (`bool`):
+ Include log probabilities in the response
+ top_logprobs (`int`):
+ Include the `n` most likely tokens at each step
+ max_tokens (`int`):
+ Maximum number of generated tokens
+ n (`int`):
+ Generate `n` completions
+ presence_penalty (`float`):
+ The parameter for presence penalty. 0.0 means no penalty. See [this
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+ stream (`bool`):
+ Stream the response
+ seed (`int`):
+ Random sampling seed
+ temperature (`float`):
+ The value used to modulate the logits distribution.
+ top_p (`float`):
+ If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+ higher are kept for generation
+ tools (`List[Tool]`):
+ List of tools that the model may call
+ tool_choice (`str`):
+ The tool to use. If not specified, the model may use any of the provided tools
+
+ """
+ request = ChatRequest(
+ model="tgi",
+ messages=messages,
+ frequency_penalty=frequency_penalty,
+ logit_bias=logit_bias,
+ logprobs=logprobs,
+ top_logprobs=top_logprobs,
+ max_tokens=max_tokens,
+ n=n,
+ presence_penalty=presence_penalty,
+ stream=stream,
+ seed=seed,
+ temperature=temperature,
+ top_p=top_p,
+ tools=tools,
+ tool_choice=tool_choice,
+ )
+ if not stream:
+ resp = requests.post(
+ f"{self.base_url}/v1/chat/completions",
+ json=request.dict(),
+ headers=self.headers,
+ cookies=self.cookies,
+ timeout=self.timeout,
+ )
+ payload = resp.json()
+ if resp.status_code != 200:
+ raise parse_error(resp.status_code, payload)
+ return ChatComplete(**payload)
+ else:
+ return self._chat_stream_response(request)
+
+ def _chat_stream_response(self, request):
+ resp = requests.post(
+ f"{self.base_url}/v1/chat/completions",
+ json=request.dict(),
+ headers=self.headers,
+ cookies=self.cookies,
+ timeout=self.timeout,
+ stream=True,
+ )
+ # iterate over the server-sent event stream
+ for byte_payload in resp.iter_lines():
+ if byte_payload == b"\n":
+ continue
+ payload = byte_payload.decode("utf-8")
+ if payload.startswith("data:"):
+ json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+ try:
+ response = ChatCompletionChunk(**json_payload)
+ yield response
+ except ValidationError:
+ raise parse_error(resp.status_code, json_payload)
+
def generate(
self,
prompt: str,
@@ -313,6 +426,113 @@ class AsyncClient:
self.cookies = cookies
self.timeout = ClientTimeout(timeout * 60)
+ async def chat(
+ self,
+ messages: List[Message],
+ frequency_penalty: Optional[float] = None,
+ logit_bias: Optional[List[float]] = None,
+ logprobs: Optional[bool] = None,
+ top_logprobs: Optional[int] = None,
+ max_tokens: Optional[int] = None,
+ n: Optional[int] = None,
+ presence_penalty: Optional[float] = None,
+ stream: bool = False,
+ seed: Optional[int] = None,
+ temperature: Optional[float] = None,
+ top_p: Optional[float] = None,
+ tools: Optional[List[Tool]] = None,
+ tool_choice: Optional[str] = None,
+ ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:
+ """
+ Given a list of messages, generate a response asynchronously
+
+ Args:
+ messages (`List[Message]`):
+ List of messages
+ frequency_penalty (`float`):
+ The parameter for frequency penalty. 0.0 means no penalty. See [this
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+ logit_bias (`List[float]`):
+ Adjust the likelihood of specified tokens
+ logprobs (`bool`):
+ Include log probabilities in the response
+ top_logprobs (`int`):
+ Include the `n` most likely tokens at each step
+ max_tokens (`int`):
+ Maximum number of generated tokens
+ n (`int`):
+ Generate `n` completions
+ presence_penalty (`float`):
+ The parameter for presence penalty. 0.0 means no penalty. See [this
+ paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+ stream (`bool`):
+ Stream the response
+ seed (`int`):
+ Random sampling seed
+ temperature (`float`):
+ The value used to modulate the logits distribution.
+ top_p (`float`):
+ If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+ higher are kept for generation
+ tools (`List[Tool]`):
+ List of tools that the model may call
+ tool_choice (`str`):
+ The tool to use. If not specified, the model may use any of the provided tools
+
+ """
+ request = ChatRequest(
+ model="tgi",
+ messages=messages,
+ frequency_penalty=frequency_penalty,
+ logit_bias=logit_bias,
+ logprobs=logprobs,
+ top_logprobs=top_logprobs,
+ max_tokens=max_tokens,
+ n=n,
+ presence_penalty=presence_penalty,
+ stream=stream,
+ seed=seed,
+ temperature=temperature,
+ top_p=top_p,
+ tools=tools,
+ tool_choice=tool_choice,
+ )
+ if not stream:
+ return await self._chat_single_response(request)
+ else:
+ return self._chat_stream_response(request)
+
+ async def _chat_single_response(self, request):
+ async with ClientSession(
+ headers=self.headers, cookies=self.cookies, timeout=self.timeout
+ ) as session:
+ async with session.post(
+ f"{self.base_url}/v1/chat/completions", json=request.dict()
+ ) as resp:
+ payload = await resp.json()
+ if resp.status != 200:
+ raise parse_error(resp.status, payload)
+ return ChatComplete(**payload)
+
+ async def _chat_stream_response(self, request):
+ async with ClientSession(
+ headers=self.headers, cookies=self.cookies, timeout=self.timeout
+ ) as session:
+ async with session.post(
+ f"{self.base_url}/v1/chat/completions", json=request.dict()
+ ) as resp:
+ async for byte_payload in resp.content:
+ if byte_payload == b"\n":
+ continue
+ payload = byte_payload.decode("utf-8")
+ if payload.startswith("data:"):
+ json_payload = json.loads(payload.lstrip("data:").rstrip("\n"))
+ try:
+ response = ChatCompletionChunk(**json_payload)
+ yield response
+ except ValidationError:
+ raise parse_error(resp.status, json_payload)
+
async def generate(
self,
prompt: str,
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 911114ee..4a308cef 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -1,6 +1,6 @@
from enum import Enum
from pydantic import BaseModel, validator
-from typing import Optional, List, Union
+from typing import Optional, List, Union, Any
from text_generation.errors import ValidationError
@@ -19,6 +19,124 @@ class Grammar(BaseModel):
value: Union[str, dict]
+class ToolCall(BaseModel):
+ # Id of the tool call
+ id: int
+ # Type of the tool call
+ type: str
+ # Function details of the tool call
+ function: dict
+
+
+class Message(BaseModel):
+ # Role of the message sender
+ role: str
+ # Content of the message
+ content: Optional[str]
+ # Optional name of the message sender
+ name: Optional[str] = None
+ # Tool calls associated with the chat completion
+ tool_calls: Optional[Any] = None
+
+
+class Tool(BaseModel):
+ # Type of the tool
+ type: str
+ # Function details of the tool
+ function: dict
+
+
+class ChatCompletionComplete(BaseModel):
+ # Index of the chat completion
+ index: int
+ # Message associated with the chat completion
+ message: Message
+ # Log probabilities for the chat completion
+ logprobs: Optional[Any]
+ # Reason for completion
+ finish_reason: str
+ # Usage details of the chat completion
+ usage: Any
+
+
+class Function(BaseModel):
+ name: Optional[str]
+ arguments: str
+
+
+class ChoiceDeltaToolCall(BaseModel):
+ index: int
+ id: str
+ type: str
+ function: Function
+
+
+class ChoiceDelta(BaseModel):
+ role: str
+ content: Optional[str]
+ tool_calls: Optional[ChoiceDeltaToolCall]
+
+
+class Choice(BaseModel):
+ index: int
+ delta: ChoiceDelta
+ logprobs: Optional[dict] = None
+ finish_reason: Optional[str] = None
+
+
+class ChatCompletionChunk(BaseModel):
+ id: str
+ object: str
+ created: int
+ model: str
+ system_fingerprint: str
+ choices: List[Choice]
+
+
+class ChatComplete(BaseModel):
+ # Chat completion details
+ id: str
+ object: str
+ created: int
+ model: str
+ system_fingerprint: str
+ choices: List[ChatCompletionComplete]
+ usage: Any
+
+
+class ChatRequest(BaseModel):
+ # Model identifier
+ model: str
+ # List of messages in the conversation
+ messages: List[Message]
+ # Penalty for frequency of new tokens
+ frequency_penalty: Optional[float] = None
+ # Bias values for token selection
+ logit_bias: Optional[List[float]] = None
+ # Whether to return log probabilities
+ logprobs: Optional[bool] = None
+ # Number of most likely tokens to return at each position
+ top_logprobs: Optional[int] = None
+ # Maximum number of tokens to generate
+ max_tokens: Optional[int] = None
+ # Number of chat completion choices to generate
+ n: Optional[int] = None
+ # Penalty for presence of new tokens
+ presence_penalty: Optional[float] = None
+ # Flag to indicate streaming response
+ stream: bool = False
+ # Random sampling seed
+ seed: Optional[int] = None
+ # Sampling temperature
+ temperature: Optional[float] = None
+ # Top-p value for nucleus sampling
+ top_p: Optional[float] = None
+ # List of tools to be used
+ tools: Optional[List[Tool]] = None
+ # Choice of tool to be used
+ tool_choice: Optional[str] = None
+
+
class Parameters(BaseModel):
# Activate logits sampling
do_sample: bool = False
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index d57a594d..964a743a 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -9,6 +9,8 @@
title: Supported Models and Hardware
- local: messages_api
title: Messages API
+ - local: guidance
+ title: Guidance
title: Getting started
- sections:
- local: basic_tutorials/consuming_tgi
diff --git a/docs/source/guidance.md b/docs/source/guidance.md
new file mode 100644
index 00000000..8b9ba094
--- /dev/null
+++ b/docs/source/guidance.md
@@ -0,0 +1,419 @@
+# Guidance
+
+Text Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.
+
+These features are available starting from version `1.4.3`. They are accessible via the [text_generation](https://pypi.org/project/text-generation/) library and are compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!
+
+## Quick Start
+
+Before we jump into the deep end, ensure your system is using TGI version `1.4.3` or later to access all the features we're about to explore in this guide.
+
+If you're not up to date, grab the latest version and let's get started!
+
+## Table of Contents 📚
+
+### Grammar and Constraints
+
+- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.
+- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.
+- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.
+- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.
+
+### Tools and Functions
+
+- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions.
+- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions.
+- [OpenAI integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
+
+## Grammar and Constraints 🛣️
+
+### The Grammar Parameter
+
+In TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the AI. This is a game-changer for those who need precise control over the AI's output.
+
+Using curl, you can make a request to TGI's Messages API with the grammar parameter. This is the most primitive way to interact with the API and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability.
+
+```bash
+curl localhost:3000/generate \
+ -X POST \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "inputs": "I saw a puppy a cat and a raccoon during my bike ride in the park",
+ "parameters": {
+ "repetition_penalty": 1.3,
+ "grammar": {
+ "type": "json",
+ "value": {
+ "properties": {
+ "location": {
+ "type": "string"
+ },
+ "activity": {
+ "type": "string"
+ },
+ "animals_seen": {
+ "type": "integer",
+ "minimum": 1,
+ "maximum": 5
+ },
+ "animals": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "required": ["location", "activity", "animals_seen", "animals"]
+ }
+ }
+ }
+}'
+// {"generated_text":"{ \n\n\"activity\": \"biking\",\n\"animals\": [\"puppy\",\"cat\",\"raccoon\"],\n\"animals_seen\": 3,\n\"location\": \"park\"\n}"}
+
+```
+
+A grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.
+
+> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is computationally expensive and may take a few seconds to complete on the first request. Subsequent requests will use the cached grammar and will be much faster.
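+
+Regular expression grammars work the same way through the raw `generate` endpoint. Below is a minimal sketch that constrains the output to an IP-address shape; the prompt, server address, and regex here are illustrative assumptions rather than a shipped example.
+
+```python
+import requests
+
+data = {
+    "inputs": "What is the IP address of the Google DNS servers?",
+    "parameters": {
+        "max_new_tokens": 30,
+        "grammar": {
+            "type": "regex",
+            "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
+        },
+    },
+}
+
+response = requests.post(
+    "http://127.0.0.1:3000/generate",
+    headers={"Content-Type": "application/json"},
+    json=data,
+)
+# The generated text is constrained to match the regex, i.e. an IPv4-looking string
+print(response.json()["generated_text"])
+```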
+
+### Constrain with Pydantic
+
+Pydantic is a powerful library for data validation and settings management. It's the perfect tool for crafting a specific response format.
+
+Using Pydantic models, we can define the same grammar as in the previous example in a shorter and more readable way.
+
+```python
+import requests
+from pydantic import BaseModel, conint
+from typing import List
+
+class Animals(BaseModel):
+ location: str
+ activity: str
+ animals_seen: conint(ge=1, le=5) # Constrained integer type
+ animals: List[str]
+
+prompt = "convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park"
+
+data = {
+ "inputs": prompt,
+ "parameters": {
+ "repetition_penalty": 1.3,
+ "grammar": {
+ "type": "json",
+ "value": Animals.schema()
+ }
+ }
+}
+
+headers = {
+ "Content-Type": "application/json",
+}
+
+response = requests.post(
+ 'http://127.0.0.1:3000/generate',
+ headers=headers,
+ json=data
+)
+print(response.json())
+# {'generated_text': '{ "activity": "bike riding", "animals": ["puppy","cat","raccoon"],"animals_seen": 3, "location":"park" }'}
+
+```
+
+### JSON Schema Integration
+
+If Pydantic's not your style, go raw with direct JSON Schema integration. It's like having a conversation with the AI in its own language. This is similar to the first example but with programmatic control.
+
+```python
+import requests
+
+json_schema = {
+ "properties": {
+ "location": {
+ "type": "string"
+ },
+ "activity": {
+ "type": "string"
+ },
+ "animals_seen": {
+ "type": "integer",
+ "minimum": 1,
+ "maximum": 5
+ },
+ "animals": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "required": ["location", "activity", "animals_seen", "animals"]
+}
+
+data = {
+ "inputs": "[INST]convert to JSON: I saw a puppy a cat and a raccoon during my bike ride in the park [/INST]",
+ "parameters": {
+ "max_new_tokens": 200,
+ "repetition_penalty": 1.3,
+ "grammar": {
+ "type": "json",
+ "value": json_schema
+ }
+ }
+}
+
+headers = {
+ "Content-Type": "application/json",
+}
+
+response = requests.post(
+ 'http://127.0.0.1:3000/generate',
+ headers=headers,
+ json=data
+)
+print(response.json())
+# {'generated_text': '{\n"activity": "biking",\n"animals": ["puppy","cat","raccoon"]\n , "animals_seen": 3,\n "location":"park"}'}
+
+```
+
+### Using the client
+
+TGI provides a client library that makes it easy to send requests with all of the parameters we've discussed above. Here's an example of how to use the client to send a request with a grammar parameter.
+
+```python
+from text_generation import AsyncClient
+from text_generation.types import GrammarType
+
+# Define an async function to encapsulate the async operation
+async def main():
+ client = AsyncClient(base_url="http://localhost:3000")
+
+ # Use 'await' to wait for the async method 'generate' to complete
+ response = await client.generate(
+ "Whats Googles DNS",
+ max_new_tokens=10,
+ decoder_input_details=True,
+ seed=1,
+ grammar={
+ "type": GrammarType.Regex,
+ "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
+ },
+ )
+
+ # Once the response is received, you can process it
+ print(response.generated_text)
+
+# Ensure the main async function is run in the event loop
+if __name__ == "__main__":
+ import asyncio
+ asyncio.run(main())
+
+# 118.8.0.84
+
+```
+
+## Tools and Functions 🛠️
+
+### The Tools Parameter
+
+In addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API.
+
+Tools are a set of user-defined functions that can be used in tandem with the chat functionality to enhance the AI's capabilities. You can use these tools to perform a variety of tasks, such as data manipulation, formatting, and more.
+
+Functions, similar to grammars, are defined as JSON schemas and can be passed as part of the parameters to the Messages API.
+
+```bash
+curl localhost:3000/v1/chat/completions \
+ -X POST \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "model": "tgi",
+ "messages": [
+ {
+ "role": "user",
+ "content": "What is the weather like in New York?"
+ }
+ ],
+ "tools": [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA"
+ },
+ "format": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "The temperature unit to use. Infer this from the users location."
+ }
+ },
+ "required": ["location", "format"]
+ }
+ }
+ }
+ ],
+ "tool_choice": "get_current_weather"
+}'
+// {"id":"","object":"text_completion","created":1709051640,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":19,"total_tokens":176}}
+```
+
+<details>
+ <summary>Tools used in example below</summary>
+
+ ```python
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "format": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "The temperature unit to use. Infer this from the users location.",
+ },
+ },
+ "required": ["location", "format"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_n_day_weather_forecast",
+ "description": "Get an N-day weather forecast",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "format": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "The temperature unit to use. Infer this from the users location.",
+ },
+ "num_days": {
+ "type": "integer",
+ "description": "The number of days to forecast",
+ },
+ },
+ "required": ["location", "format", "num_days"],
+ },
+ },
+ }
+ ]
+ ```
+
+</details>
+
+### Text Generation Inference Client
+
+TGI provides a client library to interact with the Messages API and Tool functions. The client library is available in both synchronous and asynchronous versions.
+
+```python
+from text_generation import AsyncClient
+
+# NOTE: tools defined above and removed for brevity
+
+# Define an async function to encapsulate the async operation
+async def main():
+ client = AsyncClient(base_url="http://localhost:3000")
+
+ # Use 'await' to wait for the async method 'chat' to complete
+ response = await client.chat(
+ max_tokens=100,
+ seed=1,
+ tools=tools,
+ presence_penalty=-1.1,
+ messages=[
+ {
+ "role": "system",
+ "content": "You're a helpful assistant! Answer the users question best you can.",
+ },
+ {
+ "role": "user",
+ "content": "What is the weather like in Brooklyn, New York?",
+ },
+ ],
+ )
+
+ # Once the response is received, you can process it
+ print(response.choices[0].message.tool_calls)
+
+# Ensure the main async function is run in the event loop
+if __name__ == "__main__":
+ import asyncio
+ asyncio.run(main())
+
+# {"id":"","object":"text_completion","created":1709051942,"model":"HuggingFaceH4/zephyr-7b-beta","system_fingerprint":"1.4.2-native","choices":[{"index":0,"message":{"role":"assistant","tool_calls":{"id":0,"type":"function","function":{"description":null,"name":"tools","parameters":{"format":"celsius","location":"New York"}}}},"logprobs":null,"finish_reason":"eos_token"}],"usage":{"prompt_tokens":157,"completion_tokens":20,"total_tokens":177}}
+
+```
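+
+For completeness, the synchronous `Client` exposes the same `chat` method. The sketch below is the blocking equivalent of the example above; it assumes the same `tools` list defined earlier and a TGI server running on `localhost:3000`.
+
+```python
+from text_generation import Client
+
+# NOTE: tools defined above and removed for brevity
+
+client = Client(base_url="http://localhost:3000")
+
+# With stream=False (the default) chat() blocks and returns a ChatComplete object
+response = client.chat(
+    max_tokens=100,
+    seed=1,
+    tools=tools,
+    messages=[
+        {
+            "role": "system",
+            "content": "You're a helpful assistant! Answer the users question best you can.",
+        },
+        {
+            "role": "user",
+            "content": "What is the weather like in Brooklyn, New York?",
+        },
+    ],
+)
+print(response.choices[0].message.tool_calls)
+
+# With stream=True, chat() instead returns an iterator of ChatCompletionChunk objects:
+# for chunk in client.chat(messages=[...], tools=tools, stream=True):
+#     print(chunk.choices[0].delta)
+```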
+
+### OpenAI integration
+
+TGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.
+
+However, there are some minor differences in the API; for example, `tool_choice="auto"` will ALWAYS choose a tool for you. This is different from OpenAI's API, where `tool_choice="auto"` will only choose a tool if the model thinks it's necessary.
+
+```python
+from openai import OpenAI
+
+# Initialize the client, pointing it to one of the available models
+client = OpenAI(
+ base_url="http://localhost:3000/v1",
+ api_key="_",
+)
+
+# NOTE: tools defined above and removed for brevity
+
+chat_completion = client.chat.completions.create(
+ model="tgi",
+ messages=[
+ {
+ "role": "system",
+ "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
+ },
+ {
+ "role": "user",
+ "content": "What's the weather like the next 3 days in San Francisco, CA?",
+ },
+ ],
+ tools=tools,
+ tool_choice="auto", # tool selected by model
+ max_tokens=500,
+)
+
+
+called = chat_completion.choices[0].message.tool_calls
+print(called)
+# {
+# "id": 0,
+# "type": "function",
+# "function": {
+# "description": None,
+# "name": "tools",
+# "parameters": {
+# "format": "celsius",
+# "location": "San Francisco, CA",
+# "num_days": 3,
+# },
+# },
+# }
+```
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index e11c7cf9..96cf43ad 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -23,6 +23,8 @@ from text_generation.types import (
Token,
BestOfSequence,
Grammar,
+ ChatComplete,
+ ChatCompletionChunk,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -59,6 +61,15 @@ class ResponseComparator(JSONSnapshotExtension):
) -> bool:
def convert_data(data):
data = json.loads(data)
+ if isinstance(data, Dict) and "choices" in data:
+ choices = data["choices"]
+ if (
+ isinstance(choices, List)
+ and len(choices) >= 1
+ and "delta" in choices[0]
+ ):
+ return ChatCompletionChunk(**data)
+ return ChatComplete(**data)
if isinstance(data, Dict):
return Response(**data)
@@ -144,6 +155,16 @@ class ResponseComparator(JSONSnapshotExtension):
)
)
+ def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
+ return (
+ response.choices[0].message.content == other.choices[0].message.content
+ )
+
+ def eq_chat_complete_chunk(
+ response: ChatCompletionChunk, other: ChatCompletionChunk
+ ) -> bool:
+ return response.choices[0].delta.content == other.choices[0].delta.content
+
def eq_response(response: Response, other: Response) -> bool:
return response.generated_text == other.generated_text and eq_details(
response.details, other.details
@@ -157,6 +178,19 @@ class ResponseComparator(JSONSnapshotExtension):
if not isinstance(snapshot_data, List):
snapshot_data = [snapshot_data]
+ if isinstance(serialized_data[0], ChatComplete):
+ return len(snapshot_data) == len(serialized_data) and all(
+ [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
+ )
+
+ if isinstance(serialized_data[0], ChatCompletionChunk):
+ return len(snapshot_data) == len(serialized_data) and all(
+ [
+ eq_chat_complete_chunk(r, o)
+ for r, o in zip(serialized_data, snapshot_data)
+ ]
+ )
+
return len(snapshot_data) == len(serialized_data) and all(
[eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]
)
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json
new file mode 100644
index 00000000..3c4b4aea
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_no_tools.json
@@ -0,0 +1,26 @@
+{
+ "choices": [
+ {
+ "finish_reason": "length",
+ "index": 0,
+ "logprobs": null,
+ "message": {
+ "content": "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally",
+ "name": null,
+ "role": "assistant",
+ "tool_calls": null
+ },
+ "usage": null
+ }
+ ],
+ "created": 1708957015,
+ "id": "",
+ "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+ "object": "text_completion",
+ "system_fingerprint": "1.4.2-native",
+ "usage": {
+ "completion_tokens": 100,
+ "prompt_tokens": 60,
+ "total_tokens": 160
+ }
+}
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json
new file mode 100644
index 00000000..9b9e33c6
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools.json
@@ -0,0 +1,38 @@
+{
+ "choices": [
+ {
+ "finish_reason": "eos_token",
+ "index": 0,
+ "logprobs": null,
+ "message": {
+ "content": null,
+ "name": null,
+ "role": "assistant",
+ "tool_calls": {
+ "function": {
+ "description": null,
+ "name": "tools",
+ "parameters": {
+ "format": "celsius",
+ "location": "New York, NY",
+ "num_days": 14
+ }
+ },
+ "id": 0,
+ "type": "function"
+ }
+ },
+ "usage": null
+ }
+ ],
+ "created": 1709079417,
+ "id": "",
+ "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+ "object": "text_completion",
+ "system_fingerprint": "1.4.2-native",
+ "usage": {
+ "completion_tokens": 29,
+ "prompt_tokens": 316,
+ "total_tokens": 345
+ }
+}
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json
new file mode 100644
index 00000000..de32c970
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto.json
@@ -0,0 +1,38 @@
+{
+ "choices": [
+ {
+ "finish_reason": "eos_token",
+ "index": 0,
+ "logprobs": null,
+ "message": {
+ "content": null,
+ "name": null,
+ "role": "assistant",
+ "tool_calls": {
+ "function": {
+ "description": null,
+ "name": "tools",
+ "parameters": {
+ "format": "celsius",
+ "location": "New York, NY",
+ "num_days": 14
+ }
+ },
+ "id": 0,
+ "type": "function"
+ }
+ },
+ "usage": null
+ }
+ ],
+ "created": 1709079492,
+ "id": "",
+ "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+ "object": "text_completion",
+ "system_fingerprint": "1.4.2-native",
+ "usage": {
+ "completion_tokens": 29,
+ "prompt_tokens": 316,
+ "total_tokens": 345
+ }
+}
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json
new file mode 100644
index 00000000..3551e205
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice.json
@@ -0,0 +1,37 @@
+{
+ "choices": [
+ {
+ "finish_reason": "eos_token",
+ "index": 0,
+ "logprobs": null,
+ "message": {
+ "content": null,
+ "name": null,
+ "role": "assistant",
+ "tool_calls": {
+ "function": {
+ "description": null,
+ "name": "tools",
+ "parameters": {
+ "format": "celsius",
+ "location": "New York, NY"
+ }
+ },
+ "id": 0,
+ "type": "function"
+ }
+ },
+ "usage": null
+ }
+ ],
+ "created": 1709079493,
+ "id": "",
+ "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+ "object": "text_completion",
+ "system_fingerprint": "1.4.2-native",
+ "usage": {
+ "completion_tokens": 21,
+ "prompt_tokens": 187,
+ "total_tokens": 208
+ }
+}
diff --git a/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json
new file mode 100644
index 00000000..c367cc6f
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_stream.json
@@ -0,0 +1,27 @@
+{
+ "choices": [
+ {
+ "delta": {
+ "content": null,
+ "role": "assistant",
+ "tool_calls": {
+ "function": {
+ "arguments": "",
+ "name": null
+ },
+ "id": "",
+ "index": 20,
+ "type": "function"
+ }
+ },
+ "finish_reason": "eos_token",
+ "index": 20,
+ "logprobs": null
+ }
+ ],
+ "created": 1709087088,
+ "id": "",
+ "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+ "object": "text_completion",
+ "system_fingerprint": "1.4.2-native"
+}
diff --git a/integration-tests/models/test_tools_llama.py b/integration-tests/models/test_tools_llama.py
new file mode 100644
index 00000000..38570c38
--- /dev/null
+++ b/integration-tests/models/test_tools_llama.py
@@ -0,0 +1,240 @@
+import pytest
+import json
+
+from text_generation.types import GrammarType
+
+
+@pytest.fixture(scope="module")
+def flash_llama_grammar_tools_handle(launcher):
+ with launcher(
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
+ ) as handle:
+ yield handle
+
+
+@pytest.fixture(scope="module")
+async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
+ await flash_llama_grammar_tools_handle.health(300)
+ return flash_llama_grammar_tools_handle.client
+
+
+# tools to be used in the following tests
+tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "format": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "The temperature unit to use. Infer this from the users location.",
+ },
+ },
+ "required": ["location", "format"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_n_day_weather_forecast",
+ "description": "Get an N-day weather forecast",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "format": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "The temperature unit to use. Infer this from the users location.",
+ },
+ "num_days": {
+ "type": "integer",
+ "description": "The number of days to forecast",
+ },
+ },
+ "required": ["location", "format", "num_days"],
+ },
+ },
+ },
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_no_tools(
+ flash_llama_grammar_tools, response_snapshot
+):
+ response = await flash_llama_grammar_tools.chat(
+ max_tokens=100,
+ seed=1,
+ messages=[
+ {
+ "role": "system",
+ "content": "Youre a helpful assistant! Answer the users question best you can.",
+ },
+ {
+ "role": "user",
+ "content": "What is the weather like in Brooklyn, New York?",
+ },
+ ],
+ )
+
+ assert (
+ response.choices[0].message.content
+ == "As of today, there is a Update available for the Brooklyn, New York, area. According to the latest forecast, it's warm with high temperatures throughout the day. It's forecasted at 75°F for today and 77°F for tomorrow. However, in autumn, the weather typically changes drastically, becoming cooler and wetter. You can find the current weather forecast for the area through your local weather service. Additionally"
+ )
+ assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
+ response = await flash_llama_grammar_tools.chat(
+ max_tokens=100,
+ seed=1,
+ tools=tools,
+ presence_penalty=-1.1,
+ messages=[
+ {
+ "role": "system",
+ "content": "Youre a helpful assistant! Answer the users question best you can.",
+ },
+ {
+ "role": "user",
+ "content": "What is the weather like in Brooklyn, New York?",
+ },
+ ],
+ )
+ assert response.choices[0].message.content is None
+ assert response.choices[0].message.tool_calls == {
+ "function": {
+ "description": None,
+ "name": "tools",
+ "parameters": {
+ "format": "celsius",
+ "location": "New York, NY",
+ "num_days": 14,
+ },
+ },
+ "id": 0,
+ "type": "function",
+ }
+ assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_auto(
+ flash_llama_grammar_tools, response_snapshot
+):
+ response = await flash_llama_grammar_tools.chat(
+ max_tokens=100,
+ seed=1,
+ tools=tools,
+ tool_choice="auto",
+ presence_penalty=-1.1,
+ messages=[
+ {
+ "role": "system",
+ "content": "Youre a helpful assistant! Answer the users question best you can.",
+ },
+ {
+ "role": "user",
+ "content": "What is the weather like in Brooklyn, New York?",
+ },
+ ],
+ )
+ assert response.choices[0].message.content is None
+ assert response.choices[0].message.tool_calls == {
+ "function": {
+ "description": None,
+ "name": "tools",
+ "parameters": {
+ "format": "celsius",
+ "location": "New York, NY",
+ "num_days": 14,
+ },
+ },
+ "id": 0,
+ "type": "function",
+ }
+ assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_choice(
+ flash_llama_grammar_tools, response_snapshot
+):
+ response = await flash_llama_grammar_tools.chat(
+ max_tokens=100,
+ seed=1,
+ tools=tools,
+ tool_choice="get_current_weather",
+ presence_penalty=-1.1,
+ messages=[
+ {
+ "role": "system",
+ "content": "Youre a helpful assistant! Answer the users question best you can.",
+ },
+ {
+ "role": "user",
+ "content": "What is the weather like in Brooklyn, New York?",
+ },
+ ],
+ )
+ assert response.choices[0].message.content is None
+ assert response.choices[0].message.tool_calls == {
+ "id": 0,
+ "type": "function",
+ "function": {
+ "description": None,
+ "name": "tools",
+ "parameters": {"format": "celsius", "location": "New York, NY"},
+ },
+ }
+ assert response == response_snapshot
+
+
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_llama_grammar_tools_stream(
+ flash_llama_grammar_tools, response_snapshot
+):
+ responses = await flash_llama_grammar_tools.chat(
+ max_tokens=100,
+ seed=1,
+ tools=tools,
+ tool_choice="get_current_weather",
+ presence_penalty=-1.1,
+ messages=[
+ {
+ "role": "system",
+ "content": "Youre a helpful assistant! Answer the users question best you can.",
+ },
+ {
+ "role": "user",
+ "content": "What is the weather like in Paris, France?",
+ },
+ ],
+ stream=True,
+ )
+
+ count = 0
+ async for response in responses:
+ count += 1
+
+ assert count == 20
+ assert response == response_snapshot
diff --git a/router/src/infer.rs b/router/src/infer.rs
index 472b7d66..42405327 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -812,23 +812,27 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
- content: "Hi!".to_string(),
+ content: Some("Hi!".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "assistant".to_string(),
- content: "Hello how can I help?".to_string(),
+ content: Some("Hello how can I help?".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "user".to_string(),
- content: "What is Deep Learning?".to_string(),
+ content: Some("What is Deep Learning?".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "assistant".to_string(),
- content: "magic!".to_string(),
+ content: Some("magic!".to_string()),
name: None,
+ tool_calls: None,
},
],
bos_token: Some("[BOS]"),
@@ -877,28 +881,33 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
- content: "Hi!".to_string(),
+ content: Some("Hi!".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "user".to_string(),
- content: "Hi again!".to_string(),
+ content: Some("Hi again!".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "assistant".to_string(),
- content: "Hello how can I help?".to_string(),
+ content: Some("Hello how can I help?".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "user".to_string(),
- content: "What is Deep Learning?".to_string(),
+ content: Some("What is Deep Learning?".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "assistant".to_string(),
- content: "magic!".to_string(),
+ content: Some("magic!".to_string()),
name: None,
+ tool_calls: None,
},
],
bos_token: Some("[BOS]"),
@@ -952,23 +961,27 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
- content: "Hi!".to_string(),
+ content: Some("Hi!".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "assistant".to_string(),
- content: "Hello how can I help?".to_string(),
+ content: Some("Hello how can I help?".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "user".to_string(),
- content: "What is Deep Learning?".to_string(),
+ content: Some("What is Deep Learning?".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "assistant".to_string(),
- content: "magic!".to_string(),
+ content: Some("magic!".to_string()),
name: None,
+ tool_calls: None,
},
],
bos_token: Some("[BOS]"),
@@ -1006,23 +1019,27 @@ mod tests {
messages: vec![
Message {
role: "user".to_string(),
- content: "Hi!".to_string(),
+ content: Some("Hi!".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "assistant".to_string(),
- content: "Hello how can I help?".to_string(),
+ content: Some("Hello how can I help?".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "user".to_string(),
- content: "What is Deep Learning?".to_string(),
+ content: Some("What is Deep Learning?".to_string()),
name: None,
+ tool_calls: None,
},
Message {
role: "assistant".to_string(),
- content: "magic!".to_string(),
+ content: Some("magic!".to_string()),
name: None,
+ tool_calls: None,
},
],
bos_token: Some("[BOS]"),
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 1c06eb8a..d89bacb5 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -358,10 +358,11 @@ impl ChatCompletion {
pub(crate) fn new(
model: String,
system_fingerprint: String,
- output: String,
+ output: Option<String>,
created: u64,
details: Details,
return_logprobs: bool,
+ tool_calls: Option<Vec<ToolCall>>,
) -> Self {
Self {
id: String::new(),
@@ -375,6 +376,7 @@ impl ChatCompletion {
role: "assistant".into(),
content: output,
name: None,
+ tool_calls,
},
logprobs: return_logprobs
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
@@ -413,15 +415,35 @@ pub(crate) struct ChatCompletionChoice {
pub(crate) struct ChatCompletionDelta {
#[schema(example = "user")]
pub role: String,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")]
- pub content: String,
+ pub content: Option<String>,
+ // default to None
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub tool_calls: Option<DeltaToolCall>,
}
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub(crate) struct DeltaToolCall {
+ pub index: u32,
+ pub id: String,
+ pub r#type: String,
+ pub function: Function,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
+pub(crate) struct Function {
+ pub name: Option<String>,
+ pub arguments: String,
+}
+
+#[allow(clippy::too_many_arguments)]
impl ChatCompletionChunk {
pub(crate) fn new(
model: String,
system_fingerprint: String,
- delta: String,
+ delta: Option<String>,
+ tool_calls: Option<Vec<String>>,
created: u64,
index: u32,
logprobs: Option<ChatCompletionLogprobs>,
@@ -438,6 +460,15 @@ impl ChatCompletionChunk {
delta: ChatCompletionDelta {
role: "assistant".to_string(),
content: delta,
+ tool_calls: tool_calls.map(|tc| DeltaToolCall {
+ index,
+ id: String::new(),
+ r#type: "function".to_string(),
+ function: Function {
+ name: None,
+ arguments: tc[0].to_string(),
+ },
+ }),
},
logprobs,
finish_reason,
@@ -520,6 +551,125 @@ pub(crate) struct ChatRequest {
#[serde(default)]
#[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>,
+
+ /// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
+ /// functions the model may generate JSON inputs for.
+ #[serde(default)]
+ #[schema(nullable = true, example = "null")]
+ pub tools: Option<Vec<Tool>>,
+
+ /// A prompt to be appended before the tools
+ #[serde(default = "default_tool_prompt")]
+ #[schema(
+ nullable = true,
+ example = "\"Based on the conversation, please choose the most appropriate tool to use: \""
+ )]
+ pub tool_prompt: Option<String>,
+
+ /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
+ #[serde(default)]
+ #[schema(nullable = true, example = "null")]
+ #[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
+ pub tool_choice: Option<ToolType>,
+}
+
+fn default_tool_prompt() -> Option<String> {
+ Some(
+ "\nBased on the conversation, please choose the most appropriate tool to use: ".to_string(),
+ )
+}
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+enum ToolType {
+ FunctionName(String),
+ OneOf,
+}
+
+/// Deserialize the tool choice from the JSON input or from the function name ("none" is allowed but mapped to None)
+mod deserialize_tool_choice {
+ use super::*;
+ use serde::de;
+ use serde::Deserializer;
+ use serde_json::Value;
+
+ pub fn deserialize<'de, D>(deserializer: D) -> Result