257 lines
7.6 KiB
Python
257 lines
7.6 KiB
Python
import pytest
|
|
|
|
|
|
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
    """Yield a launcher handle for the TinyLlama chat model.

    ``launcher`` is a fixture defined outside this file. It is used as a
    context manager and is called with two shards and with
    ``disable_grammar_support=False``; the handle it produces is yielded for
    the lifetime of the module-scoped fixture.
    """
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    with launcher(model_id, num_shard=2, disable_grammar_support=False) as server:
        yield server
|
|
|
|
|
|
@pytest.fixture(scope="module")
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
    """Return the handle's client once its health check has completed."""
    server = flash_llama_grammar_tools_handle
    # NOTE(review): 300 is presumably a timeout in seconds -- confirm against
    # the handle's ``health`` implementation.
    await server.health(300)
    return server.client
|
|
|
|
|
|
def _weather_tool(name, description, extra_properties=None):
    """Build one function-tool definition for the weather tools below.

    Every tool takes a ``location`` and a ``format`` string parameter;
    ``extra_properties`` (a mapping of parameter name to schema) is appended
    after them. All parameters are listed as required, in declaration order,
    and additional properties are disallowed. A fresh structure is built on
    each call so the resulting tools share no nested objects.
    """
    properties = {
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA",
        },
        "format": {
            "type": "string",
            "enum": ["celsius", "fahrenheit"],
            "description": "The temperature unit to use. Infer this from the users location.",
        },
    }
    if extra_properties:
        properties.update(extra_properties)

    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": list(properties),
                "additionalProperties": False,
            },
        },
    }


# tools to be used in the following tests
tools = [
    _weather_tool("get_current_weather", "Get the current weather"),
    _weather_tool(
        "get_n_day_weather_forecast",
        "Get an N-day weather forecast",
        {
            "num_days": {
                "type": "integer",
                "description": "The number of days to forecast",
            },
        },
    ),
]
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
    """Chat with ``tools`` supplied and no explicit ``tool_choice``.

    Expects no message content, a single ``get_current_weather`` tool call for
    Brooklyn in celsius, and a response equal to the stored snapshot.
    """
    conversation = [
        {
            "role": "system",
            "content": "Youre a helpful assistant! Answer the users question best you can.",
        },
        {
            "role": "user",
            "content": "What is the weather like in Brooklyn, New York?",
        },
    ]
    expected_calls = [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
            },
        }
    ]

    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        messages=conversation,
    )

    message = response.choices[0].message
    assert message.content is None
    assert message.tool_calls == expected_calls
    assert response == response_snapshot
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
    flash_llama_grammar_tools, response_snapshot
):
    """Chat with ``tools`` supplied and ``tool_choice="auto"``.

    Expects no message content, a single ``get_current_weather`` tool call for
    Brooklyn in celsius, and a response equal to the stored snapshot.
    """
    conversation = [
        {
            "role": "system",
            "content": "Youre a helpful assistant! Answer the users question best you can.",
        },
        {
            "role": "user",
            "content": "What is the weather like in Brooklyn, New York?",
        },
    ]
    expected_calls = [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
            },
        }
    ]

    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="auto",
        messages=conversation,
    )

    message = response.choices[0].message
    assert message.content is None
    assert message.tool_calls == expected_calls

    assert response == response_snapshot
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
    flash_llama_grammar_tools, response_snapshot
):
    """Chat with ``tool_choice`` set to the ``get_current_weather`` tool name.

    Expects no message content, a single ``get_current_weather`` tool call for
    Brooklyn in celsius, and a response equal to the stored snapshot.
    """
    conversation = [
        {
            "role": "system",
            "content": "Youre a helpful assistant! Answer the users question best you can.",
        },
        {
            "role": "user",
            "content": "What is the weather like in Brooklyn, New York?",
        },
    ]
    expected_calls = [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
            },
        }
    ]

    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="get_current_weather",
        messages=conversation,
    )

    message = response.choices[0].message
    assert message.content is None
    assert message.tool_calls == expected_calls

    assert response == response_snapshot
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
    flash_llama_grammar_tools, response_snapshot
):
    """Stream a chat forced onto ``get_current_weather`` and check the chunks.

    Every streamed chunk is collected. Chunks whose delta has no content
    contribute their tool-call argument fragment to one string; any content
    fragments are accumulated separately. The test expects 39 chunks, no
    content text, the exact concatenated tool-call string, and the collected
    chunks to equal the stored snapshot.
    """
    conversation = [
        {
            "role": "system",
            "content": "Youre a helpful assistant! Answer the users question best you can.",
        },
        {
            "role": "user",
            "content": "What is the weather like in Paris, France?",
        },
    ]

    stream = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="get_current_weather",
        messages=conversation,
        stream=True,
    )

    chunk_count = 0
    tool_call_text = ""
    content_text = ""
    chunks = []
    async for chunk in stream:
        chunk_count += 1
        delta = chunk.choices[0].delta
        fragment = delta.content
        if fragment is not None:
            content_text += fragment
        else:
            # No content on this chunk: take the tool-call arguments instead.
            tool_call_text += delta.tool_calls.function.arguments
        chunks.append(chunk)

    assert chunk_count == 39
    assert content_text == ""
    assert (
        tool_call_text
        == '{ "function": { "_name": "get_current_weather", "format": "celsius", "location": "Paris, France" } }</s>'
    )
    assert chunks == response_snapshot
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(
    flash_llama_grammar_tools, response_snapshot
):
    """Ask a non-weather question with ``tool_choice="auto"``.

    The system prompt restricts answers to weather questions while the user
    asks for a story. The test expects no message content, a first tool call
    whose function name is ``notify_error``, and a response equal to the
    stored snapshot.
    """
    conversation = [
        {
            "role": "system",
            "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
        },
        {
            "role": "user",
            "content": "Tell me a story about 3 sea creatures",
        },
    ]

    result = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=conversation,
        stream=False,
    )

    message = result.choices[0].message
    assert message.content is None
    first_call = message.tool_calls[0]
    assert first_call["function"]["name"] == "notify_error"
    assert result == response_snapshot
|