fix: correctly index into mask when applying grammar (#1618)

This PR fixes how the grammar mask is index when generating text and
adds a new test to ensure the grammars work with non flash models
This commit is contained in:
drbh 2024-03-01 12:22:01 -05:00 committed by GitHub
parent 7e08751378
commit 7dbaf9e901
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 560 additions and 208 deletions

View File

@ -1,4 +1,123 @@
[ [
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.0332031,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04257202,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4785156,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32495117,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7709961,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.33740234,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.00995636,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022735596,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
},
{ {
"details": { "details": {
"best_of_sequences": null, "best_of_sequences": null,
@ -355,124 +474,5 @@
"top_tokens": null "top_tokens": null
}, },
"generated_text": "123456@gmail.com" "generated_text": "123456@gmail.com"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.0332031,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04257202,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4785156,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32495117,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7709961,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.33740234,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.00995636,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022735596,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
} }
] ]

View File

@ -0,0 +1,274 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 30,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 5235,
"logprob": -10.0625,
"text": "info"
},
{
"id": 29901,
"logprob": -3.2324219,
"text": ":"
},
{
"id": 13260,
"logprob": -10.625,
"text": "dav"
},
{
"id": 333,
"logprob": -0.08276367,
"text": "id"
},
{
"id": 8753,
"logprob": -7.5273438,
"text": "hol"
},
{
"id": 17559,
"logprob": -3.8476562,
"text": "tz"
},
{
"id": 763,
"logprob": -10.140625,
"text": "like"
},
{
"id": 10697,
"logprob": -10.1953125,
"text": "trees"
},
{
"id": 322,
"logprob": -2.5742188,
"text": "and"
},
{
"id": 756,
"logprob": -7.4882812,
"text": "has"
},
{
"id": 1023,
"logprob": -5.0507812,
"text": "two"
},
{
"id": 274,
"logprob": -5.3164062,
"text": "c"
},
{
"id": 1446,
"logprob": -0.6694336,
"text": "ats"
},
{
"id": 29889,
"logprob": -0.9995117,
"text": "."
},
{
"id": 29871,
"logprob": -4.2421875,
"text": ""
}
],
"seed": null,
"tokens": [
{
"id": 6377,
"logprob": -0.14916992,
"special": false,
"text": "{\""
},
{
"id": 29888,
"logprob": -0.13598633,
"special": false,
"text": "f"
},
{
"id": 12935,
"logprob": -0.017669678,
"special": false,
"text": "irs"
},
{
"id": 29873,
"logprob": -0.00085639954,
"special": false,
"text": "t"
},
{
"id": 1170,
"logprob": -0.0054016113,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.13549805,
"special": false,
"text": "\":\""
},
{
"id": 19504,
"logprob": -0.8852539,
"special": false,
"text": "David"
},
{
"id": 3284,
"logprob": -0.16394043,
"special": false,
"text": "\",\""
},
{
"id": 29882,
"logprob": -0.08862305,
"special": false,
"text": "h"
},
{
"id": 711,
"logprob": -0.66259766,
"special": false,
"text": "ob"
},
{
"id": 1609,
"logprob": -5.51939e-05,
"special": false,
"text": "by"
},
{
"id": 4710,
"logprob": -0.23120117,
"special": false,
"text": "\":\""
},
{
"id": 29911,
"logprob": -2.3730469,
"special": false,
"text": "T"
},
{
"id": 11003,
"logprob": -0.032104492,
"special": false,
"text": "rees"
},
{
"id": 3284,
"logprob": -0.22021484,
"special": false,
"text": "\",\""
},
{
"id": 4230,
"logprob": -0.06726074,
"special": false,
"text": "last"
},
{
"id": 1170,
"logprob": -0.003501892,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.0045661926,
"special": false,
"text": "\":\""
},
{
"id": 29950,
"logprob": -0.12512207,
"special": false,
"text": "H"
},
{
"id": 14339,
"logprob": -0.009552002,
"special": false,
"text": "olt"
},
{
"id": 29920,
"logprob": -0.00042438507,
"special": false,
"text": "z"
},
{
"id": 3284,
"logprob": -0.11651611,
"special": false,
"text": "\",\""
},
{
"id": 29876,
"logprob": -0.29736328,
"special": false,
"text": "n"
},
{
"id": 398,
"logprob": -0.003030777,
"special": false,
"text": "um"
},
{
"id": 29907,
"logprob": -0.3774414,
"special": false,
"text": "C"
},
{
"id": 1446,
"logprob": -0.0003130436,
"special": false,
"text": "ats"
},
{
"id": 1115,
"logprob": -0.0021514893,
"special": false,
"text": "\":"
},
{
"id": 29906,
"logprob": -0.071899414,
"special": false,
"text": "2"
},
{
"id": 29913,
"logprob": -0.018997192,
"special": false,
"text": "}"
},
{
"id": 2,
"logprob": 0.0,
"special": true,
"text": "</s>"
}
],
"top_tokens": null
},
"generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}"
}

View File

@ -0,0 +1,150 @@
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
def flash_llama_grammar_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_grammar(flash_llama_grammar_handle):
await flash_llama_grammar_handle.health(300)
return flash_llama_grammar_handle.client
@pytest.mark.asyncio
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Whats Googles DNS",
max_new_tokens=10,
decoder_input_details=True,
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
},
)
assert response.details.generated_tokens == 10
assert response.generated_text == "42.1.1.101"
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"info: david holtz like trees and has two cats. ",
max_new_tokens=100,
decoder_input_details=True,
seed=0,
grammar={
"type": GrammarType.Json, # "json"
"value": json.dumps(
{
"type": "object",
"$id": "https://example.com/person.schema.json",
"$schema": "https://json-schema.org/draft/2020-12/schema",
"title": "Person",
"properties": {
"firstName": {
"type": "string",
"description": "The person'''s first name.",
},
"lastName": {
"type": "string",
"description": "The person'''s last name.",
},
"hobby": {
"description": "The person'''s hobby.",
"type": "string",
},
"numCats": {
"description": "The number of cats the person has.",
"type": "integer",
"minimum": 0,
},
},
"required": ["firstName", "lastName", "hobby", "numCats"],
}
),
},
)
assert response.details.generated_tokens == 30
assert (
response.generated_text
== '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
)
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_llama_grammar_load(
flash_llama_grammar, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_grammar,
"name: david. email: ",
max_new_tokens=10,
n=4,
stop_sequences=[".com"],
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
assert len(responses) == 4
expected = "123456@gmail.com"
for response in responses:
assert response.generated_text == expected
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
# this is the same as the above test, but only fires off a single request
# this is only to ensure that the parallel and single inference produce the same result
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_llama_grammar_single_load_instance(
flash_llama_grammar, generate_load, response_snapshot
):
response = await flash_llama_grammar.generate(
"name: david. email: ",
max_new_tokens=10,
stop_sequences=[".com"],
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
# assert response.details.generated_tokens == 30
assert response.generated_text == "123456@gmail.com"
assert response == response_snapshot

View File

@ -5,56 +5,32 @@ from text_generation.types import GrammarType
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def flash_llama_grammar_handle(launcher): def non_flash_llama_grammar_handle(launcher):
with launcher( with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
num_shard=1,
disable_grammar_support=False,
use_flash_attention=False,
) as handle: ) as handle:
yield handle yield handle
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
async def flash_llama_grammar(flash_llama_grammar_handle): async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
await flash_llama_grammar_handle.health(300) await non_flash_llama_grammar_handle.health(300)
return flash_llama_grammar_handle.client return non_flash_llama_grammar_handle.client
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot): async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate( response = await non_flash_llama_grammar.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Whats Googles DNS",
max_new_tokens=10,
decoder_input_details=True,
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
},
)
assert response.details.generated_tokens == 10
assert response.generated_text == "42.1.1.101"
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"info: david holtz like trees and has two cats. ", "info: david holtz like trees and has two cats. ",
max_new_tokens=100, max_new_tokens=100,
decoder_input_details=True, decoder_input_details=True,
seed=0, seed=0,
grammar={ grammar={
"type": GrammarType.Json, # "json" "type": GrammarType.Json,
"value": json.dumps( "value": json.dumps(
{ {
"type": "object", "type": "object",
@ -92,55 +68,3 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
== '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}' == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
) )
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_llama_grammar_load(
flash_llama_grammar, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_grammar,
"name: david. email: ",
max_new_tokens=10,
n=4,
stop_sequences=[".com"],
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
assert len(responses) == 4
expected = "123456@gmail.com"
for response in responses:
assert response.generated_text == expected
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
# this is the same as the above test, but only fires off a single request
# this is only to ensure that the parallel and single inference produce the same result
@pytest.mark.asyncio
async def test_flash_llama_grammar_single_load_instance(
flash_llama_grammar, generate_load, response_snapshot
):
response = await flash_llama_grammar.generate(
"name: david. email: ",
max_new_tokens=10,
stop_sequences=[".com"],
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
# assert response.details.generated_tokens == 30
assert response.generated_text == "123456@gmail.com"
assert response == response_snapshot

View File

@ -98,6 +98,7 @@ async def test_flash_llama_grammar_no_tools(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot): async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
@ -134,6 +135,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_auto( async def test_flash_llama_grammar_tools_auto(
@ -173,6 +175,7 @@ async def test_flash_llama_grammar_tools_auto(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_choice( async def test_flash_llama_grammar_tools_choice(
@ -208,6 +211,7 @@ async def test_flash_llama_grammar_tools_choice(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_grammar_tools_stream( async def test_flash_llama_grammar_tools_stream(

View File

@ -491,7 +491,7 @@ class GrammarLogitProcessor(LogitsProcessor):
return logits return logits
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
mask = torch.full_like(logits, -math.inf) mask = torch.full_like(logits, -math.inf)
mask[allowed_tokens] = 0 mask[:, allowed_tokens] = 0
biased_scores = logits + mask biased_scores = logits + mask
return biased_scores return biased_scores