import pytest


@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
    """Launch a 2-shard Llama-3.1-8B-Instruct server with grammar support enabled.

    Yields the launcher handle; the server is torn down when the module's
    tests finish (module-scoped fixture).
    """
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        num_shard=2,
        disable_grammar_support=False,
    ) as handle:
        yield handle
@pytest.fixture(scope="module")
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
    """Wait (up to 300s) for the launched server to report healthy, then return its client."""
    await flash_llama_grammar_tools_handle.health(300)
    return flash_llama_grammar_tools_handle.client
# Tool definitions (OpenAI function-calling schema) shared by the tests below.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
                # Keep the grammar strict: the model may not invent extra keys.
                "additionalProperties": False,
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["location", "format", "num_days"],
                "additionalProperties": False,
            },
        },
    },
]
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
    """Tools supplied, no explicit tool_choice: the model should emit a tool call."""
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    # A tool-call response carries no plain-text content.
    assert response.choices[0].message.content is None
    assert response.choices[0].message.tool_calls == [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
            },
        }
    ]
    assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
    flash_llama_grammar_tools, response_snapshot
):
    """tool_choice="auto": the model is free to pick a tool and should pick get_current_weather."""
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    # A tool-call response carries no plain-text content.
    assert response.choices[0].message.content is None
    assert response.choices[0].message.tool_calls == [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
            },
        }
    ]

    assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
    flash_llama_grammar_tools, response_snapshot
):
    """Explicit tool_choice by name: the named tool must be the one invoked."""
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="get_current_weather",
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    # A tool-call response carries no plain-text content.
    assert response.choices[0].message.content is None
    assert response.choices[0].message.tool_calls == [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
            },
        }
    ]

    assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
    flash_llama_grammar_tools, response_snapshot
):
    """Forced tool_choice with stream=True: tool-call arguments arrive chunked across deltas."""
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="get_current_weather",
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Paris, France?",
            },
        ],
        stream=True,
    )

    count = 0
    tool_calls_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        # Each streamed delta contributes a fragment of the tool-call arguments.
        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
        last_response = response
        # While streaming a tool call, the content delta stays empty.
        assert response.choices[0].delta.content is None

    assert (
        tool_calls_generated
        == '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Paris, France"}}<|eot_id|>'
    )
    assert count == 28
    assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(
    flash_llama_grammar_tools, response_snapshot
):
    """A question no tool can answer ("Who are you?"): no tool call, plain content instead."""
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "Who are you?",
            },
        ],
        stream=False,
    )

    assert responses.choices[0].message.tool_calls is None
    assert responses.choices[0].message.content == "I am an AI assistant"
    assert responses == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information_stream(
    flash_llama_grammar_tools, response_snapshot
):
    """Streaming variant of the insufficient-information case: content deltas, no tool calls."""
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "Who are you?",
            },
        ],
        stream=True,
    )

    count = 0
    content_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        content_generated += response.choices[0].delta.content
        last_response = response
        # No tool should be invoked for a question the tools cannot answer.
        assert response.choices[0].delta.tool_calls is None

    assert count == 5
    assert content_generated == "I am an AI assistant"
    assert last_response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream(
    flash_llama_grammar_tools, response_snapshot
):
    """Open-ended creative prompt with tools available: model should answer in prose, not call a tool."""
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        stream=True,
    )

    count = 0
    content_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        content_generated += response.choices[0].delta.content
        last_response = response
        # No tool should be invoked for a storytelling prompt.
        assert response.choices[0].delta.tool_calls is None

    assert count == 62
    assert (
        content_generated
        == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
    )
    assert last_response == response_snapshot