2024-02-15 02:28:10 -07:00
|
|
|
import pytest
|
|
|
|
import json
|
|
|
|
|
|
|
|
from text_generation.types import GrammarType
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
def non_flash_llama_grammar_handle(launcher):
    """Yield a launcher handle for TinyLlama with grammar support enabled.

    The model is started through the ``launcher`` fixture on a single shard,
    with ``disable_grammar_support=False`` and ``use_flash_attention=False``.
    The handle is yielded from inside the launcher's context manager, so the
    context is exited once the module-scoped fixture is torn down.
    """
    launched = launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        num_shard=1,
        disable_grammar_support=False,
        use_flash_attention=False,
    )
    with launched as server:
        yield server
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="module")
async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
    """Return the client of the launched handle after its health check.

    Awaits ``health(300)`` on the handle before handing out ``handle.client``.
    """
    handle = non_flash_llama_grammar_handle
    # NOTE(review): 300 is presumably a timeout in seconds -- confirm against
    # the handle's ``health`` implementation.
    await handle.health(300)
    return handle.client
|
2024-02-15 02:28:10 -07:00
|
|
|
|
|
|
|
|
2024-06-25 08:53:20 -06:00
|
|
|
@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):
    """Generate with a JSON grammar and compare against the recorded output.

    A "Person" JSON schema is serialized with ``json.dumps`` and passed as the
    ``GrammarType.Json`` grammar value. The test then checks the generated
    token count, the exact generated text, and equality with the snapshot.
    The test carries an unconditional ``skip`` mark.
    """
    # Key order is kept exactly as written: json.dumps serializes dict entries
    # in insertion order, so reordering would change the grammar string.
    person_schema = {
        "type": "object",
        "$id": "https://example.com/person.schema.json",
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "title": "Person",
        "properties": {
            "firstName": {
                "type": "string",
                "description": "The person'''s first name.",
            },
            "lastName": {
                "type": "string",
                "description": "The person'''s last name.",
            },
            "hobby": {
                "description": "The person'''s hobby.",
                "type": "string",
            },
            "numCats": {
                "description": "The number of cats the person has.",
                "type": "integer",
                "minimum": 0,
            },
        },
        "required": ["firstName", "lastName", "hobby", "numCats"],
    }
    expected_text = (
        '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
    )

    response = await non_flash_llama_grammar.generate(
        "info: david holtz like trees and has two cats. ",
        max_new_tokens=100,
        decoder_input_details=True,
        seed=0,
        grammar={
            "type": GrammarType.Json,
            "value": json.dumps(person_schema),
        },
    )

    assert response.details.generated_tokens == 30
    assert response.generated_text == expected_text
    assert response == response_snapshot
|