fix: correctly index into mask when applying grammar (#1618)
This PR fixes how the grammar mask is index when generating text and adds a new test to ensure the grammars work with non flash models
This commit is contained in:
parent
7e08751378
commit
7dbaf9e901
|
@ -1,4 +1,123 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 1024,
|
||||
"logprob": -10.578125,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -3.0332031,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 13260,
|
||||
"logprob": -9.171875,
|
||||
"text": "dav"
|
||||
},
|
||||
{
|
||||
"id": 333,
|
||||
"logprob": -0.04257202,
|
||||
"text": "id"
|
||||
},
|
||||
{
|
||||
"id": 29889,
|
||||
"logprob": -2.4785156,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 4876,
|
||||
"logprob": -10.7890625,
|
||||
"text": "email"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -0.32495117,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"logprob": -9.4921875,
|
||||
"text": " "
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29896,
|
||||
"logprob": -0.7709961,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.33740234,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 29941,
|
||||
"logprob": -0.00995636,
|
||||
"special": false,
|
||||
"text": "3"
|
||||
},
|
||||
{
|
||||
"id": 29946,
|
||||
"logprob": -0.64208984,
|
||||
"special": false,
|
||||
"text": "4"
|
||||
},
|
||||
{
|
||||
"id": 29945,
|
||||
"logprob": -0.4970703,
|
||||
"special": false,
|
||||
"text": "5"
|
||||
},
|
||||
{
|
||||
"id": 29953,
|
||||
"logprob": -0.46533203,
|
||||
"special": false,
|
||||
"text": "6"
|
||||
},
|
||||
{
|
||||
"id": 29992,
|
||||
"logprob": -0.5336914,
|
||||
"special": false,
|
||||
"text": "@"
|
||||
},
|
||||
{
|
||||
"id": 21980,
|
||||
"logprob": -0.5361328,
|
||||
"special": false,
|
||||
"text": "gmail"
|
||||
},
|
||||
{
|
||||
"id": 29889,
|
||||
"logprob": -0.00088739395,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -0.0022735596,
|
||||
"special": false,
|
||||
"text": "com"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "123456@gmail.com"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
|
@ -355,124 +474,5 @@
|
|||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "123456@gmail.com"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 1024,
|
||||
"logprob": -10.578125,
|
||||
"text": "name"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -3.0332031,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 13260,
|
||||
"logprob": -9.171875,
|
||||
"text": "dav"
|
||||
},
|
||||
{
|
||||
"id": 333,
|
||||
"logprob": -0.04257202,
|
||||
"text": "id"
|
||||
},
|
||||
{
|
||||
"id": 29889,
|
||||
"logprob": -2.4785156,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 4876,
|
||||
"logprob": -10.7890625,
|
||||
"text": "email"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -0.32495117,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 259,
|
||||
"logprob": -9.4921875,
|
||||
"text": " "
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29896,
|
||||
"logprob": -0.7709961,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.33740234,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 29941,
|
||||
"logprob": -0.00995636,
|
||||
"special": false,
|
||||
"text": "3"
|
||||
},
|
||||
{
|
||||
"id": 29946,
|
||||
"logprob": -0.64208984,
|
||||
"special": false,
|
||||
"text": "4"
|
||||
},
|
||||
{
|
||||
"id": 29945,
|
||||
"logprob": -0.4970703,
|
||||
"special": false,
|
||||
"text": "5"
|
||||
},
|
||||
{
|
||||
"id": 29953,
|
||||
"logprob": -0.46533203,
|
||||
"special": false,
|
||||
"text": "6"
|
||||
},
|
||||
{
|
||||
"id": 29992,
|
||||
"logprob": -0.5336914,
|
||||
"special": false,
|
||||
"text": "@"
|
||||
},
|
||||
{
|
||||
"id": 21980,
|
||||
"logprob": -0.5361328,
|
||||
"special": false,
|
||||
"text": "gmail"
|
||||
},
|
||||
{
|
||||
"id": 29889,
|
||||
"logprob": -0.00088739395,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -0.0022735596,
|
||||
"special": false,
|
||||
"text": "com"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "123456@gmail.com"
|
||||
}
|
||||
]
|
|
@ -0,0 +1,274 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 30,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 5235,
|
||||
"logprob": -10.0625,
|
||||
"text": "info"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": -3.2324219,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 13260,
|
||||
"logprob": -10.625,
|
||||
"text": "dav"
|
||||
},
|
||||
{
|
||||
"id": 333,
|
||||
"logprob": -0.08276367,
|
||||
"text": "id"
|
||||
},
|
||||
{
|
||||
"id": 8753,
|
||||
"logprob": -7.5273438,
|
||||
"text": "hol"
|
||||
},
|
||||
{
|
||||
"id": 17559,
|
||||
"logprob": -3.8476562,
|
||||
"text": "tz"
|
||||
},
|
||||
{
|
||||
"id": 763,
|
||||
"logprob": -10.140625,
|
||||
"text": "like"
|
||||
},
|
||||
{
|
||||
"id": 10697,
|
||||
"logprob": -10.1953125,
|
||||
"text": "trees"
|
||||
},
|
||||
{
|
||||
"id": 322,
|
||||
"logprob": -2.5742188,
|
||||
"text": "and"
|
||||
},
|
||||
{
|
||||
"id": 756,
|
||||
"logprob": -7.4882812,
|
||||
"text": "has"
|
||||
},
|
||||
{
|
||||
"id": 1023,
|
||||
"logprob": -5.0507812,
|
||||
"text": "two"
|
||||
},
|
||||
{
|
||||
"id": 274,
|
||||
"logprob": -5.3164062,
|
||||
"text": "c"
|
||||
},
|
||||
{
|
||||
"id": 1446,
|
||||
"logprob": -0.6694336,
|
||||
"text": "ats"
|
||||
},
|
||||
{
|
||||
"id": 29889,
|
||||
"logprob": -0.9995117,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 29871,
|
||||
"logprob": -4.2421875,
|
||||
"text": ""
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 6377,
|
||||
"logprob": -0.14916992,
|
||||
"special": false,
|
||||
"text": "{\""
|
||||
},
|
||||
{
|
||||
"id": 29888,
|
||||
"logprob": -0.13598633,
|
||||
"special": false,
|
||||
"text": "f"
|
||||
},
|
||||
{
|
||||
"id": 12935,
|
||||
"logprob": -0.017669678,
|
||||
"special": false,
|
||||
"text": "irs"
|
||||
},
|
||||
{
|
||||
"id": 29873,
|
||||
"logprob": -0.00085639954,
|
||||
"special": false,
|
||||
"text": "t"
|
||||
},
|
||||
{
|
||||
"id": 1170,
|
||||
"logprob": -0.0054016113,
|
||||
"special": false,
|
||||
"text": "Name"
|
||||
},
|
||||
{
|
||||
"id": 4710,
|
||||
"logprob": -0.13549805,
|
||||
"special": false,
|
||||
"text": "\":\""
|
||||
},
|
||||
{
|
||||
"id": 19504,
|
||||
"logprob": -0.8852539,
|
||||
"special": false,
|
||||
"text": "David"
|
||||
},
|
||||
{
|
||||
"id": 3284,
|
||||
"logprob": -0.16394043,
|
||||
"special": false,
|
||||
"text": "\",\""
|
||||
},
|
||||
{
|
||||
"id": 29882,
|
||||
"logprob": -0.08862305,
|
||||
"special": false,
|
||||
"text": "h"
|
||||
},
|
||||
{
|
||||
"id": 711,
|
||||
"logprob": -0.66259766,
|
||||
"special": false,
|
||||
"text": "ob"
|
||||
},
|
||||
{
|
||||
"id": 1609,
|
||||
"logprob": -5.51939e-05,
|
||||
"special": false,
|
||||
"text": "by"
|
||||
},
|
||||
{
|
||||
"id": 4710,
|
||||
"logprob": -0.23120117,
|
||||
"special": false,
|
||||
"text": "\":\""
|
||||
},
|
||||
{
|
||||
"id": 29911,
|
||||
"logprob": -2.3730469,
|
||||
"special": false,
|
||||
"text": "T"
|
||||
},
|
||||
{
|
||||
"id": 11003,
|
||||
"logprob": -0.032104492,
|
||||
"special": false,
|
||||
"text": "rees"
|
||||
},
|
||||
{
|
||||
"id": 3284,
|
||||
"logprob": -0.22021484,
|
||||
"special": false,
|
||||
"text": "\",\""
|
||||
},
|
||||
{
|
||||
"id": 4230,
|
||||
"logprob": -0.06726074,
|
||||
"special": false,
|
||||
"text": "last"
|
||||
},
|
||||
{
|
||||
"id": 1170,
|
||||
"logprob": -0.003501892,
|
||||
"special": false,
|
||||
"text": "Name"
|
||||
},
|
||||
{
|
||||
"id": 4710,
|
||||
"logprob": -0.0045661926,
|
||||
"special": false,
|
||||
"text": "\":\""
|
||||
},
|
||||
{
|
||||
"id": 29950,
|
||||
"logprob": -0.12512207,
|
||||
"special": false,
|
||||
"text": "H"
|
||||
},
|
||||
{
|
||||
"id": 14339,
|
||||
"logprob": -0.009552002,
|
||||
"special": false,
|
||||
"text": "olt"
|
||||
},
|
||||
{
|
||||
"id": 29920,
|
||||
"logprob": -0.00042438507,
|
||||
"special": false,
|
||||
"text": "z"
|
||||
},
|
||||
{
|
||||
"id": 3284,
|
||||
"logprob": -0.11651611,
|
||||
"special": false,
|
||||
"text": "\",\""
|
||||
},
|
||||
{
|
||||
"id": 29876,
|
||||
"logprob": -0.29736328,
|
||||
"special": false,
|
||||
"text": "n"
|
||||
},
|
||||
{
|
||||
"id": 398,
|
||||
"logprob": -0.003030777,
|
||||
"special": false,
|
||||
"text": "um"
|
||||
},
|
||||
{
|
||||
"id": 29907,
|
||||
"logprob": -0.3774414,
|
||||
"special": false,
|
||||
"text": "C"
|
||||
},
|
||||
{
|
||||
"id": 1446,
|
||||
"logprob": -0.0003130436,
|
||||
"special": false,
|
||||
"text": "ats"
|
||||
},
|
||||
{
|
||||
"id": 1115,
|
||||
"logprob": -0.0021514893,
|
||||
"special": false,
|
||||
"text": "\":"
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.071899414,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 29913,
|
||||
"logprob": -0.018997192,
|
||||
"special": false,
|
||||
"text": "}"
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": 0.0,
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}"
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
import pytest
|
||||
import json
|
||||
|
||||
from text_generation.types import GrammarType
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_grammar_handle(launcher):
|
||||
with launcher(
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_grammar(flash_llama_grammar_handle):
|
||||
await flash_llama_grammar_handle.health(300)
|
||||
return flash_llama_grammar_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
|
||||
response = await flash_llama_grammar.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
|
||||
response = await flash_llama_grammar.generate(
|
||||
"Whats Googles DNS",
|
||||
max_new_tokens=10,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
grammar={
|
||||
"type": GrammarType.Regex, # "regex"
|
||||
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response.generated_text == "42.1.1.101"
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
||||
response = await flash_llama_grammar.generate(
|
||||
"info: david holtz like trees and has two cats. ",
|
||||
max_new_tokens=100,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
grammar={
|
||||
"type": GrammarType.Json, # "json"
|
||||
"value": json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"$id": "https://example.com/person.schema.json",
|
||||
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||
"title": "Person",
|
||||
"properties": {
|
||||
"firstName": {
|
||||
"type": "string",
|
||||
"description": "The person'''s first name.",
|
||||
},
|
||||
"lastName": {
|
||||
"type": "string",
|
||||
"description": "The person'''s last name.",
|
||||
},
|
||||
"hobby": {
|
||||
"description": "The person'''s hobby.",
|
||||
"type": "string",
|
||||
},
|
||||
"numCats": {
|
||||
"description": "The number of cats the person has.",
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
},
|
||||
},
|
||||
"required": ["firstName", "lastName", "hobby", "numCats"],
|
||||
}
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 30
|
||||
assert (
|
||||
response.generated_text
|
||||
== '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_grammar_load(
|
||||
flash_llama_grammar, generate_load, response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_llama_grammar,
|
||||
"name: david. email: ",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
stop_sequences=[".com"],
|
||||
seed=0,
|
||||
grammar={
|
||||
"type": GrammarType.Regex, # "regex"
|
||||
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
||||
},
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
|
||||
expected = "123456@gmail.com"
|
||||
|
||||
for response in responses:
|
||||
assert response.generated_text == expected
|
||||
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
||||
|
||||
# this is the same as the above test, but only fires off a single request
|
||||
# this is only to ensure that the parallel and single inference produce the same result
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_grammar_single_load_instance(
|
||||
flash_llama_grammar, generate_load, response_snapshot
|
||||
):
|
||||
response = await flash_llama_grammar.generate(
|
||||
"name: david. email: ",
|
||||
max_new_tokens=10,
|
||||
stop_sequences=[".com"],
|
||||
seed=0,
|
||||
grammar={
|
||||
"type": GrammarType.Regex, # "regex"
|
||||
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
||||
},
|
||||
)
|
||||
|
||||
# assert response.details.generated_tokens == 30
|
||||
assert response.generated_text == "123456@gmail.com"
|
||||
|
||||
assert response == response_snapshot
|
|
@ -5,56 +5,32 @@ from text_generation.types import GrammarType
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_llama_grammar_handle(launcher):
|
||||
def non_flash_llama_grammar_handle(launcher):
|
||||
with launcher(
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
num_shard=1,
|
||||
disable_grammar_support=False,
|
||||
use_flash_attention=False,
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_llama_grammar(flash_llama_grammar_handle):
|
||||
await flash_llama_grammar_handle.health(300)
|
||||
return flash_llama_grammar_handle.client
|
||||
async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
|
||||
await non_flash_llama_grammar_handle.health(300)
|
||||
return non_flash_llama_grammar_handle.client
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
|
||||
response = await flash_llama_grammar.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
|
||||
response = await flash_llama_grammar.generate(
|
||||
"Whats Googles DNS",
|
||||
max_new_tokens=10,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
grammar={
|
||||
"type": GrammarType.Regex, # "regex"
|
||||
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response.generated_text == "42.1.1.101"
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
||||
response = await flash_llama_grammar.generate(
|
||||
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):
|
||||
response = await non_flash_llama_grammar.generate(
|
||||
"info: david holtz like trees and has two cats. ",
|
||||
max_new_tokens=100,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
grammar={
|
||||
"type": GrammarType.Json, # "json"
|
||||
"type": GrammarType.Json,
|
||||
"value": json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
|
@ -92,55 +68,3 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
|||
== '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_grammar_load(
|
||||
flash_llama_grammar, generate_load, response_snapshot
|
||||
):
|
||||
responses = await generate_load(
|
||||
flash_llama_grammar,
|
||||
"name: david. email: ",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
stop_sequences=[".com"],
|
||||
seed=0,
|
||||
grammar={
|
||||
"type": GrammarType.Regex, # "regex"
|
||||
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
||||
},
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
|
||||
expected = "123456@gmail.com"
|
||||
|
||||
for response in responses:
|
||||
assert response.generated_text == expected
|
||||
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
||||
|
||||
# this is the same as the above test, but only fires off a single request
|
||||
# this is only to ensure that the parallel and single inference produce the same result
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_llama_grammar_single_load_instance(
|
||||
flash_llama_grammar, generate_load, response_snapshot
|
||||
):
|
||||
response = await flash_llama_grammar.generate(
|
||||
"name: david. email: ",
|
||||
max_new_tokens=10,
|
||||
stop_sequences=[".com"],
|
||||
seed=0,
|
||||
grammar={
|
||||
"type": GrammarType.Regex, # "regex"
|
||||
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
||||
},
|
||||
)
|
||||
|
||||
# assert response.details.generated_tokens == 30
|
||||
assert response.generated_text == "123456@gmail.com"
|
||||
|
||||
assert response == response_snapshot
|
||||
|
|
|
@ -98,6 +98,7 @@ async def test_flash_llama_grammar_no_tools(
|
|||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
|
||||
|
@ -134,6 +135,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
|
|||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_auto(
|
||||
|
@ -173,6 +175,7 @@ async def test_flash_llama_grammar_tools_auto(
|
|||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_choice(
|
||||
|
@ -208,6 +211,7 @@ async def test_flash_llama_grammar_tools_choice(
|
|||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_llama_grammar_tools_stream(
|
||||
|
|
|
@ -491,7 +491,7 @@ class GrammarLogitProcessor(LogitsProcessor):
|
|||
return logits
|
||||
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
|
||||
mask = torch.full_like(logits, -math.inf)
|
||||
mask[allowed_tokens] = 0
|
||||
mask[:, allowed_tokens] = 0
|
||||
biased_scores = logits + mask
|
||||
return biased_scores
|
||||
|
||||
|
|
Loading…
Reference in New Issue