Outlines guided generation (#1539)

This WIP PR starts to add grammar support via outlines, currently this
PR supports very simple regex grammars and does not optimize for
precompiling or caching grammar fsm's.

todo:
- [X] add simple outlines guidance to `NextTokenChooser`
- [X] update protos for grammar
- [X] update generation params API
- [X] constrain simple grammar
- [ ] support parsing more complex grammar into fsm
- [ ] support all outline support grammar types
- [ ] explore optimizations to avoid recompiling grammars

guided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data-raw '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6,
        "grammar": "[\\w-]+@([\\w-]+\\.)+[\\w-]+"
    }
}' | jq
```
response
```json
{
  "generated_text": "david@example.com"
}
```

unguided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6
    }
}' | jq
```
response
```json
{
  "generated_text": "    email = 'david"
}
```
This commit is contained in:
drbh 2024-02-15 04:28:10 -05:00 committed by GitHub
parent 4c2848b24b
commit cef0553d59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 1660 additions and 53 deletions

View File

@ -8,7 +8,7 @@ use crate::app::App;
use crate::event::Event; use crate::event::Event;
use crossterm::ExecutableCommand; use crossterm::ExecutableCommand;
use std::io; use std::io;
use text_generation_client::{NextTokenChooserParameters, ShardedClient}; use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc}; use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend; use tui::backend::CrosstermBackend;
@ -45,6 +45,8 @@ pub async fn run(
repetition_penalty: repetition_penalty.unwrap_or(1.0), repetition_penalty: repetition_penalty.unwrap_or(1.0),
frequency_penalty: frequency_penalty.unwrap_or(0.0), frequency_penalty: frequency_penalty.unwrap_or(0.0),
watermark, watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}; };
// Initialize terminal properties // Initialize terminal properties

View File

@ -10,6 +10,7 @@ from text_generation.types import (
Response, Response,
Request, Request,
Parameters, Parameters,
Grammar,
) )
from text_generation.errors import parse_error from text_generation.errors import parse_error
@ -76,6 +77,7 @@ class Client:
watermark: bool = False, watermark: bool = False,
decoder_input_details: bool = False, decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> Response: ) -> Response:
""" """
Given a prompt, generate the following text Given a prompt, generate the following text
@ -138,6 +140,7 @@ class Client:
watermark=watermark, watermark=watermark,
decoder_input_details=decoder_input_details, decoder_input_details=decoder_input_details,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
grammar=grammar,
) )
request = Request(inputs=prompt, stream=False, parameters=parameters) request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -169,6 +172,7 @@ class Client:
typical_p: Optional[float] = None, typical_p: Optional[float] = None,
watermark: bool = False, watermark: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> Iterator[StreamResponse]: ) -> Iterator[StreamResponse]:
""" """
Given a prompt, generate the following stream of tokens Given a prompt, generate the following stream of tokens
@ -227,6 +231,7 @@ class Client:
typical_p=typical_p, typical_p=typical_p,
watermark=watermark, watermark=watermark,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
grammar=grammar,
) )
request = Request(inputs=prompt, stream=True, parameters=parameters) request = Request(inputs=prompt, stream=True, parameters=parameters)
@ -326,6 +331,7 @@ class AsyncClient:
watermark: bool = False, watermark: bool = False,
decoder_input_details: bool = False, decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> Response: ) -> Response:
""" """
Given a prompt, generate the following text asynchronously Given a prompt, generate the following text asynchronously
@ -370,6 +376,7 @@ class AsyncClient:
Returns: Returns:
Response: generated response Response: generated response
""" """
# Validate parameters # Validate parameters
parameters = Parameters( parameters = Parameters(
best_of=best_of, best_of=best_of,
@ -388,6 +395,7 @@ class AsyncClient:
typical_p=typical_p, typical_p=typical_p,
watermark=watermark, watermark=watermark,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
grammar=grammar,
) )
request = Request(inputs=prompt, stream=False, parameters=parameters) request = Request(inputs=prompt, stream=False, parameters=parameters)
@ -417,6 +425,7 @@ class AsyncClient:
typical_p: Optional[float] = None, typical_p: Optional[float] = None,
watermark: bool = False, watermark: bool = False,
top_n_tokens: Optional[int] = None, top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> AsyncIterator[StreamResponse]: ) -> AsyncIterator[StreamResponse]:
""" """
Given a prompt, generate the following stream of tokens asynchronously Given a prompt, generate the following stream of tokens asynchronously
@ -475,6 +484,7 @@ class AsyncClient:
typical_p=typical_p, typical_p=typical_p,
watermark=watermark, watermark=watermark,
top_n_tokens=top_n_tokens, top_n_tokens=top_n_tokens,
grammar=grammar,
) )
request = Request(inputs=prompt, stream=True, parameters=parameters) request = Request(inputs=prompt, stream=True, parameters=parameters)

View File

@ -1,10 +1,24 @@
from enum import Enum from enum import Enum
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
from typing import Optional, List from typing import Optional, List, Union
from text_generation.errors import ValidationError from text_generation.errors import ValidationError
# enum for grammar type
class GrammarType(str, Enum):
Json = "json"
Regex = "regex"
# Grammar type and value
class Grammar(BaseModel):
# Grammar type
type: GrammarType
# Grammar value
value: Union[str, dict]
class Parameters(BaseModel): class Parameters(BaseModel):
# Activate logits sampling # Activate logits sampling
do_sample: bool = False do_sample: bool = False
@ -41,6 +55,8 @@ class Parameters(BaseModel):
decoder_input_details: bool = False decoder_input_details: bool = False
# Return the N most likely tokens at each step # Return the N most likely tokens at each step
top_n_tokens: Optional[int] = None top_n_tokens: Optional[int] = None
# grammar to use for generation
grammar: Optional[Grammar] = None
@validator("best_of") @validator("best_of")
def valid_best_of(cls, field_value, values): def valid_best_of(cls, field_value, values):
@ -109,6 +125,14 @@ class Parameters(BaseModel):
raise ValidationError("`top_n_tokens` must be strictly positive") raise ValidationError("`top_n_tokens` must be strictly positive")
return v return v
@validator("grammar")
def valid_grammar(cls, v):
if v is not None:
if v.type == GrammarType.Regex and not v.value:
raise ValidationError("`value` cannot be empty for `regex` grammar")
if v.type == GrammarType.Json and not v.value:
raise ValidationError("`value` cannot be empty for `json` grammar")
return v
class Request(BaseModel): class Request(BaseModel):
# Prompt # Prompt
@ -157,7 +181,7 @@ class Token(BaseModel):
# Token text # Token text
text: str text: str
# Logprob # Logprob
logprob: float logprob: Optional[float] = None
# Is the token a special token # Is the token a special token
# Can be used to ignore tokens when concatenating # Can be used to ignore tokens when concatenating
special: bool special: bool

View File

@ -378,6 +378,14 @@ Options:
[env: TOKENIZER_CONFIG_PATH=] [env: TOKENIZER_CONFIG_PATH=]
```
## DISABLE_GRAMMAR_SUPPORT
```shell
--disable-grammar-support
Disable outlines grammar constrained generation. This is a feature that allows you to generate text that follows a specific grammar
[env: DISABLE_GRAMMAR_SUPPORT=]
``` ```
## ENV ## ENV
```shell ```shell

View File

@ -16,7 +16,14 @@ from syrupy.extensions.json import JSONSnapshotExtension
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from text_generation import AsyncClient from text_generation import AsyncClient
from text_generation.types import Response, Details, InputToken, Token, BestOfSequence from text_generation.types import (
Response,
Details,
InputToken,
Token,
BestOfSequence,
Grammar,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None) HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
@ -224,6 +231,7 @@ def launcher(event_loop):
quantize: Optional[str] = None, quantize: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_flash_attention: bool = True, use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None, dtype: Optional[str] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
@ -247,6 +255,8 @@ def launcher(event_loop):
env = os.environ env = os.environ
if disable_grammar_support:
args.append("--disable-grammar-support")
if num_shard is not None: if num_shard is not None:
args.extend(["--num-shard", str(num_shard)]) args.extend(["--num-shard", str(num_shard)])
if quantize is not None: if quantize is not None:
@ -287,12 +297,15 @@ def launcher(event_loop):
quantize: Optional[str] = None, quantize: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_flash_attention: bool = True, use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None, dtype: Optional[str] = None,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
args = ["--model-id", model_id, "--env"] args = ["--model-id", model_id, "--env"]
if disable_grammar_support:
args.append("--disable-grammar-support")
if num_shard is not None: if num_shard is not None:
args.extend(["--num-shard", str(num_shard)]) args.extend(["--num-shard", str(num_shard)])
if quantize is not None: if quantize is not None:
@ -370,11 +383,22 @@ def launcher(event_loop):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def generate_load(): def generate_load():
async def generate_load_inner( async def generate_load_inner(
client: AsyncClient, prompt: str, max_new_tokens: int, n: int client: AsyncClient,
prompt: str,
max_new_tokens: int,
n: int,
seed: Optional[int] = None,
grammar: Optional[Grammar] = None,
stop_sequences: Optional[List[str]] = None,
) -> List[Response]: ) -> List[Response]:
futures = [ futures = [
client.generate( client.generate(
prompt, max_new_tokens=max_new_tokens, decoder_input_details=True prompt,
max_new_tokens=max_new_tokens,
decoder_input_details=True,
seed=seed,
grammar=grammar,
stop_sequences=stop_sequences,
) )
for _ in range(n) for _ in range(n)
] ]

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -13.90625,
"text": "Test"
},
{
"id": 2009,
"logprob": -12.328125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0566406,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -1.5253906,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.7578125,
"special": false,
"text": "I"
},
{
"id": 4966,
"logprob": -1.9033203,
"special": false,
"text": " hope"
},
{
"id": 445,
"logprob": -0.5019531,
"special": false,
"text": " this"
},
{
"id": 6911,
"logprob": -0.21264648,
"special": false,
"text": " helps"
},
{
"id": 29991,
"logprob": -0.5991211,
"special": false,
"text": "!"
},
{
"id": 2803,
"logprob": -0.37475586,
"special": false,
"text": " Let"
},
{
"id": 592,
"logprob": -0.018463135,
"special": false,
"text": " me"
},
{
"id": 1073,
"logprob": -0.0008597374,
"special": false,
"text": " know"
}
],
"top_tokens": null
},
"generated_text": "\n\nI hope this helps! Let me know"
}

View File

@ -0,0 +1,274 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 30,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 5235,
"logprob": -10.0625,
"text": "info"
},
{
"id": 29901,
"logprob": -3.2324219,
"text": ":"
},
{
"id": 13260,
"logprob": -10.625,
"text": "dav"
},
{
"id": 333,
"logprob": -0.08276367,
"text": "id"
},
{
"id": 8753,
"logprob": -7.5273438,
"text": "hol"
},
{
"id": 17559,
"logprob": -3.8476562,
"text": "tz"
},
{
"id": 763,
"logprob": -10.140625,
"text": "like"
},
{
"id": 10697,
"logprob": -10.1953125,
"text": "trees"
},
{
"id": 322,
"logprob": -2.5742188,
"text": "and"
},
{
"id": 756,
"logprob": -7.4882812,
"text": "has"
},
{
"id": 1023,
"logprob": -5.0507812,
"text": "two"
},
{
"id": 274,
"logprob": -5.3164062,
"text": "c"
},
{
"id": 1446,
"logprob": -0.6694336,
"text": "ats"
},
{
"id": 29889,
"logprob": -0.9995117,
"text": "."
},
{
"id": 29871,
"logprob": -4.2421875,
"text": ""
}
],
"seed": null,
"tokens": [
{
"id": 6377,
"logprob": -0.14916992,
"special": false,
"text": "{\""
},
{
"id": 29888,
"logprob": -0.13598633,
"special": false,
"text": "f"
},
{
"id": 12935,
"logprob": -0.017669678,
"special": false,
"text": "irs"
},
{
"id": 29873,
"logprob": -0.00085639954,
"special": false,
"text": "t"
},
{
"id": 1170,
"logprob": -0.0054016113,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.13549805,
"special": false,
"text": "\":\""
},
{
"id": 19504,
"logprob": -0.8852539,
"special": false,
"text": "David"
},
{
"id": 3284,
"logprob": -0.16394043,
"special": false,
"text": "\",\""
},
{
"id": 4230,
"logprob": -0.020492554,
"special": false,
"text": "last"
},
{
"id": 1170,
"logprob": -0.0013818741,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.0067749023,
"special": false,
"text": "\":\""
},
{
"id": 29950,
"logprob": -0.11578369,
"special": false,
"text": "H"
},
{
"id": 14339,
"logprob": -0.004131317,
"special": false,
"text": "olt"
},
{
"id": 29920,
"logprob": -0.0033359528,
"special": false,
"text": "z"
},
{
"id": 3284,
"logprob": -0.20471191,
"special": false,
"text": "\",\""
},
{
"id": 29882,
"logprob": -0.0069274902,
"special": false,
"text": "h"
},
{
"id": 20838,
"logprob": -0.19580078,
"special": false,
"text": "obb"
},
{
"id": 29891,
"logprob": -2.2649765e-06,
"special": false,
"text": "y"
},
{
"id": 4710,
"logprob": -0.32080078,
"special": false,
"text": "\":\""
},
{
"id": 29911,
"logprob": -2.1035156,
"special": false,
"text": "T"
},
{
"id": 11003,
"logprob": -0.020767212,
"special": false,
"text": "rees"
},
{
"id": 3284,
"logprob": -0.6010742,
"special": false,
"text": "\",\""
},
{
"id": 29876,
"logprob": -0.57666016,
"special": false,
"text": "n"
},
{
"id": 398,
"logprob": -0.0061073303,
"special": false,
"text": "um"
},
{
"id": 29907,
"logprob": -0.45703125,
"special": false,
"text": "C"
},
{
"id": 1446,
"logprob": -0.0002872944,
"special": false,
"text": "ats"
},
{
"id": 1115,
"logprob": -0.0021018982,
"special": false,
"text": "\":"
},
{
"id": 29906,
"logprob": -0.08996582,
"special": false,
"text": "2"
},
{
"id": 29913,
"logprob": -0.021697998,
"special": false,
"text": "}"
},
{
"id": 2,
"logprob": 0.0,
"special": true,
"text": "</s>"
}
],
"top_tokens": null
},
"generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}"
}

View File

@ -0,0 +1,478 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.03125,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04244995,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4863281,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32714844,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7685547,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.2376709,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.01008606,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64160156,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.5,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46557617,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5341797,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022907257,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.0332031,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04257202,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4785156,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32495117,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7709961,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.23840332,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.00995636,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022735596,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.0332031,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04257202,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4785156,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32495117,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7709961,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.23840332,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.00995636,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022735596,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.0332031,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04257202,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4785156,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32495117,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7709961,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.23840332,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.00995636,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022735596,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
}
]

View File

@ -0,0 +1,109 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 806,
"logprob": -11.890625,
"text": "Wh"
},
{
"id": 1446,
"logprob": -3.6699219,
"text": "ats"
},
{
"id": 2921,
"logprob": -7.8203125,
"text": "Go"
},
{
"id": 468,
"logprob": -8.0703125,
"text": "og"
},
{
"id": 793,
"logprob": -2.1875,
"text": "les"
},
{
"id": 16332,
"logprob": -9.7109375,
"text": "DNS"
}
],
"seed": null,
"tokens": [
{
"id": 29946,
"logprob": -1.4765625,
"special": false,
"text": "4"
},
{
"id": 29906,
"logprob": -0.9199219,
"special": false,
"text": "2"
},
{
"id": 29889,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 29896,
"logprob": -1.1367188,
"special": false,
"text": "1"
},
{
"id": 29889,
"logprob": -1.4648438,
"special": false,
"text": "."
},
{
"id": 29896,
"logprob": -0.40722656,
"special": false,
"text": "1"
},
{
"id": 29889,
"logprob": -0.17419434,
"special": false,
"text": "."
},
{
"id": 29896,
"logprob": -0.20251465,
"special": false,
"text": "1"
},
{
"id": 29900,
"logprob": -1.5527344,
"special": false,
"text": "0"
},
{
"id": 29896,
"logprob": -1.3710938,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": "42.1.1.101"
}

View File

@ -0,0 +1,73 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7685547,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.33666992,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.009979248,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.53759766,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.0008878708,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.002275467,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
}

View File

@ -0,0 +1,151 @@
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
def flash_llama_grammar_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_grammar(flash_llama_grammar_handle):
await flash_llama_grammar_handle.health(300)
return flash_llama_grammar_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Whats Googles DNS",
max_new_tokens=10,
decoder_input_details=True,
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
},
)
assert response.details.generated_tokens == 10
assert response.generated_text == "42.1.1.101"
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"info: david holtz like trees and has two cats. ",
max_new_tokens=100,
decoder_input_details=True,
seed=0,
grammar={
"type": GrammarType.Json, # "json"
"value": json.dumps(
{
"type": "object",
"$id": "https://example.com/person.schema.json",
"$schema": "https://json-schema.org/draft/2020-12/schema",
"title": "Person",
"properties": {
"firstName": {
"type": "string",
"description": "The person'''s first name.",
},
"lastName": {
"type": "string",
"description": "The person'''s last name.",
},
"hobby": {
"description": "The person'''s hobby.",
"type": "string",
},
"numCats": {
"description": "The number of cats the person has.",
"type": "integer",
"minimum": 0,
},
},
"required": ["firstName", "lastName", "hobby", "numCats"],
}
),
},
)
assert response.details.generated_tokens == 30
assert (
response.generated_text
== '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}'
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_load(
flash_llama_grammar, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_grammar,
"name: david. email: ",
max_new_tokens=10,
n=4,
stop_sequences=[".com"],
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
assert len(responses) == 4
expected = "123456@gmail.com"
for response in responses:
assert response.generated_text == expected
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
# this is the same as the above test, but only fires off a single request
# this is only to ensure that the parallel and single inference produce the same result
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_single_load_instance(
flash_llama_grammar, generate_load, response_snapshot
):
response = await flash_llama_grammar.generate(
"name: david. email: ",
max_new_tokens=10,
stop_sequences=[".com"],
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
# assert response.details.generated_tokens == 30
assert response.generated_text == "123456@gmail.com"
assert response == response_snapshot

View File

@ -382,6 +382,11 @@ struct Args {
#[clap(long, env)] #[clap(long, env)]
tokenizer_config_path: Option<String>, tokenizer_config_path: Option<String>,
/// Disable outlines grammar constrained generation.
/// This is a feature that allows you to generate text that follows a specific grammar.
#[clap(long, env)]
disable_grammar_support: bool,
/// Display a lot of information about your runtime environment /// Display a lot of information about your runtime environment
#[clap(long, short, action)] #[clap(long, short, action)]
env: bool, env: bool,
@ -1051,6 +1056,11 @@ fn spawn_webserver(
args.model_id, args.model_id,
]; ];
// Grammar support
if args.disable_grammar_support {
router_args.push("--disable-grammar-support".to_string());
}
// Tokenizer config path // Tokenizer config path
if let Some(ref tokenizer_config_path) = args.tokenizer_config_path { if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
router_args.push("--tokenizer-config-path".to_string()); router_args.push("--tokenizer-config-path".to_string());

View File

@ -51,6 +51,12 @@ message ClearCacheRequest {
/// Empty response /// Empty response
message ClearCacheResponse {} message ClearCacheResponse {}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
}
message NextTokenChooserParameters { message NextTokenChooserParameters {
/// exponential scaling output probability distribution /// exponential scaling output probability distribution
float temperature = 1; float temperature = 1;
@ -70,6 +76,10 @@ message NextTokenChooserParameters {
float frequency_penalty = 9; float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models" /// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8; bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
} }
message StoppingCriteriaParameters { message StoppingCriteriaParameters {

View File

@ -128,6 +128,8 @@ impl Client {
repetition_penalty: 1.2, repetition_penalty: 1.2,
frequency_penalty: 0.1, frequency_penalty: 0.1,
watermark: true, watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}), }),
stopping_parameters: Some(StoppingCriteriaParameters { stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate, max_new_tokens: max_total_tokens - truncate,

View File

@ -9,8 +9,8 @@ pub use client::Client;
pub use pb::generate::v2::HealthResponse; pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo; pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{ pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
Request, StoppingCriteriaParameters, Tokens, NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
}; };
pub use sharded_client::ShardedClient; pub use sharded_client::ShardedClient;
use thiserror::Error; use thiserror::Error;

View File

@ -1,5 +1,6 @@
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::GrammarType as ProtoGrammarType;
use text_generation_client::{ use text_generation_client::{
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
}; };
@ -45,6 +46,8 @@ impl Health {
repetition_penalty: 1.0, repetition_penalty: 1.0,
frequency_penalty: 0.0, frequency_penalty: 0.0,
watermark: false, watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
}), }),
stopping_parameters: Some(StoppingCriteriaParameters { stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1, max_new_tokens: 1,

View File

@ -45,6 +45,43 @@ impl HubTokenizerConfig {
} }
} }
mod json_object_or_string_to_string {
use serde::{Deserialize, Deserializer};
use serde_json::Value;
// A custom deserializer that treats both strings and objects as strings.
// This provides flexibility with input formats for the 'grammar' field.
pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(s),
// Safely handle serialization and return an error if it fails
Value::Object(o) => {
serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
}
_ => Err(serde::de::Error::custom(
"expected string or object for grammar",
)),
}
}
}
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
#[serde(
rename = "json",
deserialize_with = "json_object_or_string_to_string::deserialize"
)]
Json(String),
#[serde(rename = "regex")]
Regex(String),
}
mod token_serde { mod token_serde {
use super::*; use super::*;
use serde::de; use serde::de;
@ -201,6 +238,8 @@ pub(crate) struct GenerateParameters {
#[serde(default)] #[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>, pub top_n_tokens: Option<u32>,
#[serde(default)]
pub grammar: Option<GrammarType>,
} }
fn default_max_new_tokens() -> Option<u32> { fn default_max_new_tokens() -> Option<u32> {
@ -226,6 +265,7 @@ fn default_parameters() -> GenerateParameters {
decoder_input_details: false, decoder_input_details: false,
seed: None, seed: None,
top_n_tokens: None, top_n_tokens: None,
grammar: None,
} }
} }

View File

@ -75,6 +75,8 @@ struct Args {
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)] #[clap(long, env, default_value_t = false)]
messages_api_enabled: bool, messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
} }
#[tokio::main] #[tokio::main]
@ -108,6 +110,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken, ngrok_authtoken,
ngrok_edge, ngrok_edge,
messages_api_enabled, messages_api_enabled,
disable_grammar_support,
} = args; } = args;
// Launch Tokio runtime // Launch Tokio runtime
@ -359,6 +362,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_edge, ngrok_edge,
tokenizer_config, tokenizer_config,
messages_api_enabled, messages_api_enabled,
disable_grammar_support,
) )
.await?; .await?;
Ok(()) Ok(())

View File

@ -343,7 +343,9 @@ enum QueueCommand {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use tracing::info_span; use tracing::info_span;
fn default_entry() -> ( fn default_entry() -> (
@ -354,7 +356,7 @@ mod tests {
let entry = Entry { let entry = Entry {
request: ValidGenerateRequest { request: ValidGenerateRequest {
inputs: "".to_string(), inputs: String::new(),
input_length: 0, input_length: 0,
truncate: 0, truncate: 0,
decoder_input_details: false, decoder_input_details: false,
@ -368,6 +370,8 @@ mod tests {
repetition_penalty: 0.0, repetition_penalty: 0.0,
frequency_penalty: 0.0, frequency_penalty: 0.0,
watermark: false, watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
}, },
stopping_parameters: StoppingCriteriaParameters { stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false, ignore_eos_token: false,

View File

@ -614,6 +614,7 @@ async fn chat_completions(
decoder_input_details: !stream, decoder_input_details: !stream,
seed, seed,
top_n_tokens: None, top_n_tokens: None,
grammar: None,
}, },
}; };
@ -779,6 +780,7 @@ pub async fn run(
ngrok_edge: Option<String>, ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig, tokenizer_config: HubTokenizerConfig,
messages_api_enabled: bool, messages_api_enabled: bool,
grammar_support: bool,
) -> Result<(), axum::BoxError> { ) -> Result<(), axum::BoxError> {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -840,6 +842,7 @@ pub async fn run(
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
grammar_support,
); );
let generation_health = Arc::new(AtomicBool::new(false)); let generation_health = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(client.clone(), generation_health.clone()); let health_ext = Health::new(client.clone(), generation_health.clone());

View File

@ -1,8 +1,10 @@
/// Payload validation logic /// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest}; use crate::{GenerateParameters, GenerateRequest, GrammarType};
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use thiserror::Error; use thiserror::Error;
use tokenizers::tokenizer::Tokenizer; use tokenizers::tokenizer::Tokenizer;
use tokenizers::TruncationDirection; use tokenizers::TruncationDirection;
@ -19,6 +21,7 @@ pub struct Validation {
max_top_n_tokens: u32, max_top_n_tokens: u32,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize, max_total_tokens: usize,
disable_grammar_support: bool,
/// Channel to communicate with the background tokenization task /// Channel to communicate with the background tokenization task
sender: Option<mpsc::UnboundedSender<TokenizerRequest>>, sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
} }
@ -32,6 +35,7 @@ impl Validation {
max_top_n_tokens: u32, max_top_n_tokens: u32,
max_input_length: usize, max_input_length: usize,
max_total_tokens: usize, max_total_tokens: usize,
disable_grammar_support: bool,
) -> Self { ) -> Self {
// If we have a fast tokenizer // If we have a fast tokenizer
let sender = if let Some(tokenizer) = tokenizer { let sender = if let Some(tokenizer) = tokenizer {
@ -66,6 +70,7 @@ impl Validation {
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
disable_grammar_support,
} }
} }
@ -182,6 +187,7 @@ impl Validation {
watermark, watermark,
decoder_input_details, decoder_input_details,
top_n_tokens, top_n_tokens,
grammar,
.. ..
} = request.parameters; } = request.parameters;
@ -292,6 +298,28 @@ impl Validation {
.validate_input(request.inputs, truncate, max_new_tokens) .validate_input(request.inputs, truncate, max_new_tokens)
.await?; .await?;
// TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
// NOTE: this is currently difficult because we need the tokenizer in Python to build
// the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
// may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
// compiler and use that to build the FSM here.
// Validate grammar and unpack the grammar and type for the proto message
let (grammar, grammar_type) = match grammar {
Some(grammar) => {
// Ensure that grammar is not set if it's not supported
if self.disable_grammar_support {
return Err(ValidationError::Grammar);
}
match grammar {
// currently both are handled the same way since compilation is done in Python
GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
}
}
None => (String::new(), ProtoGrammarType::None.into()),
};
let parameters = NextTokenChooserParameters { let parameters = NextTokenChooserParameters {
temperature, temperature,
repetition_penalty, repetition_penalty,
@ -302,6 +330,8 @@ impl Validation {
do_sample, do_sample,
seed, seed,
watermark, watermark,
grammar,
grammar_type,
}; };
let stopping_parameters = StoppingCriteriaParameters { let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens, max_new_tokens,
@ -453,6 +483,8 @@ pub enum ValidationError {
StopSequence(usize, usize), StopSequence(usize, usize),
#[error("tokenizer error {0}")] #[error("tokenizer error {0}")]
Tokenizer(String), Tokenizer(String),
#[error("grammar is not supported")]
Grammar,
} }
#[cfg(test)] #[cfg(test)]
@ -470,6 +502,7 @@ mod tests {
let max_input_length = 5; let max_input_length = 5;
let max_total_tokens = 6; let max_total_tokens = 6;
let workers = 1; let workers = 1;
let disable_grammar_support = true;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
@ -478,6 +511,7 @@ mod tests {
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
disable_grammar_support,
); );
let max_new_tokens = 10; let max_new_tokens = 10;
@ -498,6 +532,7 @@ mod tests {
let max_top_n_tokens = 4; let max_top_n_tokens = 4;
let max_input_length = 5; let max_input_length = 5;
let max_total_tokens = 6; let max_total_tokens = 6;
let disable_grammar_support = true;
let workers = 1; let workers = 1;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
@ -507,6 +542,7 @@ mod tests {
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
disable_grammar_support,
); );
let max_new_tokens = 10; let max_new_tokens = 10;
@ -528,6 +564,7 @@ mod tests {
let max_input_length = 5; let max_input_length = 5;
let max_total_tokens = 6; let max_total_tokens = 6;
let workers = 1; let workers = 1;
let disable_grammar_support = true;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
@ -536,6 +573,7 @@ mod tests {
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
disable_grammar_support,
); );
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
@ -562,6 +600,7 @@ mod tests {
let max_input_length = 5; let max_input_length = 5;
let max_total_tokens = 106; let max_total_tokens = 106;
let workers = 1; let workers = 1;
let disable_grammar_support = true;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
@ -570,6 +609,7 @@ mod tests {
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
disable_grammar_support,
); );
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {
@ -625,6 +665,7 @@ mod tests {
let max_input_length = 5; let max_input_length = 5;
let max_total_tokens = 106; let max_total_tokens = 106;
let workers = 1; let workers = 1;
let disable_grammar_support = true;
let validation = Validation::new( let validation = Validation::new(
workers, workers,
tokenizer, tokenizer,
@ -633,6 +674,7 @@ mod tests {
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_length,
max_total_tokens, max_total_tokens,
disable_grammar_support,
); );
match validation match validation
.validate(GenerateRequest { .validate(GenerateRequest {

View File

@ -34,6 +34,7 @@ peft = { version = "^0.8.2", optional = true }
torch = { version = "^2.1.1", optional = true } torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1" scipy = "^1.11.1"
pillow = "^10.0.0" pillow = "^10.0.0"
outlines="^0.0.27"
[tool.poetry.extras] [tool.poetry.extras]
torch = ["torch"] torch = ["torch"]

View File

@ -87,7 +87,9 @@ class CausalLMBatch(Batch):
for i, r in enumerate(pb.requests): for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
inputs.append(r.inputs) inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
)
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
) )
@ -413,14 +415,14 @@ class CausalLMBatch(Batch):
# We slice the keys to remove the padding from previous batches # We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1 past_seq_len = batch.max_input_length - 1
if batch.keys_head_dim_last: if batch.keys_head_dim_last:
padded_past_keys[ padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
start_index:end_index, :, -past_seq_len:, : past_keys[:, :, -past_seq_len:, :]
] = past_keys[:, :, -past_seq_len:, :] )
else: else:
# BLOOM case # BLOOM case
padded_past_keys[ padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
start_index:end_index, :, :, -past_seq_len: past_keys[:, :, :, -past_seq_len:]
] = past_keys[:, :, :, -past_seq_len:] )
del past_keys del past_keys
start_index = end_index start_index = end_index
@ -438,9 +440,9 @@ class CausalLMBatch(Batch):
end_index = start_index + len(batch) end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches # We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1 past_seq_len = batch.max_input_length - 1
padded_past_values[ padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
start_index:end_index, :, -past_seq_len:, : past_values[:, :, -past_seq_len:, :]
] = past_values[:, :, -past_seq_len:, :] )
del past_values del past_values
# Update values # Update values
@ -504,9 +506,11 @@ class CausalLM(Model):
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
device_map="auto" device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1 if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None, else None
),
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -696,7 +700,7 @@ class CausalLM(Model):
if top_n_tokens > 0: if top_n_tokens > 0:
all_top_tokens = [] all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip( for top_token_ids, top_token_logprobs in zip(
top_token_ids, top_token_logprobs top_token_ids, top_token_logprobs
): ):
toptoken_texts = self.tokenizer.batch_decode( toptoken_texts = self.tokenizer.batch_decode(
@ -735,6 +739,9 @@ class CausalLM(Model):
generations.append(generation) generations.append(generation)
# Update values # Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.input_ids[i, 0] = next_token_id batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length batch.input_lengths[i] = new_input_length

View File

@ -237,7 +237,7 @@ class FlashCausalLMBatch(Batch):
) )
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device next_token_chooser_parameters, dtype, device, tokenizer
) )
start_slots = torch.tensor(start_slots, dtype=torch.int64) start_slots = torch.tensor(start_slots, dtype=torch.int64)
@ -593,6 +593,7 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters, next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype, dtype=batches[0].next_token_chooser.dtype,
device=batches[0].next_token_chooser.device, device=batches[0].next_token_chooser.device,
tokenizer=batches[0].next_token_chooser.tokenizer,
) )
speculative_ids = ( speculative_ids = (
@ -869,7 +870,11 @@ class FlashCausalLM(Model):
# Try to find an associated cuda graph # Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None) cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None: if (
cu_seqlen_prefill is not None
or cuda_graph is None
or batch.speculative_ids is not None
):
return self.model.forward( return self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
@ -1013,9 +1018,9 @@ class FlashCausalLM(Model):
# Copy batch.input_ids to prefill_token_indices # Copy batch.input_ids to prefill_token_indices
if prefill_logprobs: if prefill_logprobs:
if len(batch) > 1: if len(batch) > 1:
prefill_tokens_indices[ prefill_tokens_indices[out_start_index : out_end_index - 1] = (
out_start_index : out_end_index - 1 batch.input_ids[start_index + 1 : start_index + out_length]
] = batch.input_ids[start_index + 1 : start_index + out_length] )
else: else:
# Set prefill_tokens_indices to the correct slice # Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[ prefill_tokens_indices = batch.input_ids[
@ -1028,6 +1033,7 @@ class FlashCausalLM(Model):
cumulative_length += input_length cumulative_length += input_length
# Update values
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.speculative_ids = speculative_ids batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids batch.position_ids = next_position_ids + accepted_ids
@ -1166,7 +1172,7 @@ class FlashCausalLM(Model):
if top_n_tokens > 0: if top_n_tokens > 0:
all_top_tokens = [] all_top_tokens = []
for (top_token_ids, top_token_logprobs) in zip( for top_token_ids, top_token_logprobs in zip(
top_token_ids, top_token_logprobs top_token_ids, top_token_logprobs
): ):
toptoken_texts = self.tokenizer.batch_decode( toptoken_texts = self.tokenizer.batch_decode(
@ -1204,6 +1210,12 @@ class FlashCausalLM(Model):
generations.append(generation) generations.append(generation)
# accept each new token for this specific request since we may
# have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids:
batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
# Update values # Update values
batch.input_lengths[i] = input_length + n_accepted_ids batch.input_lengths[i] = input_length + n_accepted_ids
if batch.input_lengths[i] > batch.max_seqlen: if batch.input_lengths[i] > batch.max_seqlen:

View File

@ -192,7 +192,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
) )
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device next_token_chooser_parameters, dtype, device, tokenizer
) )
start_slots = torch.tensor(start_slots, dtype=torch.int64) start_slots = torch.tensor(start_slots, dtype=torch.int64)

View File

@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic # Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs)) inputs.append(escape_custom_split_sequence(r.inputs))
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
) )

View File

@ -114,7 +114,7 @@ class IdeficsCausalLMBatch(Batch):
for i, r in enumerate(pb.requests): for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
inputs.append(r.inputs) inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
) )
@ -815,6 +815,9 @@ class IdeficsCausalLM(Model):
generations.append(generation) generations.append(generation)
# Update values # Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.input_ids[i, 0] = next_token_id batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length batch.input_lengths[i] = new_input_length

View File

@ -124,7 +124,7 @@ class MambaBatch(Batch):
for i, r in enumerate(pb.requests): for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
inputs.append(r.inputs) inputs.append(r.inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
) )
@ -694,6 +694,9 @@ class Mamba(Model):
generations.append(generation) generations.append(generation)
# Update values # Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.input_ids[i, 0] = next_token_id batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length batch.input_lengths[i] = new_input_length

View File

@ -96,7 +96,7 @@ class Seq2SeqLMBatch(Batch):
inputs.append(r.inputs) inputs.append(r.inputs)
requests_idx_mapping[r.id] = i requests_idx_mapping[r.id] = i
decoder_input_lengths.append(1) decoder_input_lengths.append(1)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
stopping_criteria = StoppingCriteria.from_pb( stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer r.stopping_parameters, tokenizer
) )
@ -789,6 +789,9 @@ class Seq2SeqLM(Model):
generations.append(generation) generations.append(generation)
# Update values # Update values
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
next_token_id_squeezed.item()
)
batch.decoder_input_ids[i] = next_token_id batch.decoder_input_ids[i] = next_token_id
batch.all_decoder_input_ids[i] = all_decoder_input_ids batch.all_decoder_input_ids[i] = all_decoder_input_ids
batch.input_lengths[i] = input_length batch.input_lengths[i] = input_length

View File

@ -1,8 +1,17 @@
import math import math
import torch import torch
import json
from loguru import logger
from functools import lru_cache from functools import lru_cache
from typing import Optional, List, Dict, Union from typing import Optional, List, Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
from functools import lru_cache
from typing import List, Optional, DefaultDict
import time
from transformers import ( from transformers import (
LogitsWarper, LogitsWarper,
@ -135,9 +144,7 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
) -> torch.FloatTensor: ) -> torch.FloatTensor:
score = torch.gather(scores, 1, input_ids) score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability # if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where( score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
score < 0, score * self.penalty, score / self.penalty
)
return scores.scatter_add_(1, input_ids, score) return scores.scatter_add_(1, input_ids, score)
@ -464,3 +471,132 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
self.processors = new_processors self.processors = new_processors
return self return self
return None return None
class GrammarLogitProcessor(LogitsProcessor):
fsm_state: DefaultDict[int, int]
fsm: RegexFSM
def __init__(self, tokenizer, device, grammar, grammar_type):
self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type, grammar, self.tokenizer
)
def __call__(
self,
logits: torch.Tensor,
fsm_grammar_state: int,
):
if fsm_grammar_state == -1 or self.fsm is None:
return logits
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
mask[allowed_tokens] = 0
biased_scores = logits + mask
return biased_scores
def advance(self, next_token_id, fsm_grammar_state):
return GrammarLogitProcessor._advance(
next_token_id, fsm_grammar_state, self.fsm
)
@staticmethod
def _advance(next_token_id, fsm_grammar_state, fsm):
if fsm_grammar_state == -1:
return fsm_grammar_state
return fsm.next_state(fsm_grammar_state, next_token_id)
# TODO: move grammar compilation into the router
@staticmethod
@lru_cache(maxsize=32, typed=True)
def _cached_compile_fsm(grammar_type, schema, tokenizer):
start_time = time.time()
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
schema = build_regex_from_object(schema)
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
pass # schema is already a regex just here for clarity
fsm = RegexFSM(schema, tokenizer)
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
return fsm
@staticmethod
@lru_cache(maxsize=32, typed=True)
def _cached_adapt_tokenizer(tokenizer):
"""Adapt tokenizer to work with the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.
"""
start_time = time.time()
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = tokenizer.convert_tokens_to_string([token])
# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
tokenizer.convert_token_to_string = convert_token_to_string
logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
return tokenizer
def filter(self, indices):
new_fsms = []
for i in indices:
new_fsms.append(self.fsms[i])
self.fsms = new_fsms
return self
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
def __init__(self, tokenizer, device, grammars, grammar_type):
self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsms = []
for i in range(len(grammars)):
fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type[i], grammars[i], self.tokenizer
)
self.fsms.append(fsm)
def __call__(
self,
logits: torch.Tensor,
fsm_grammar_states: List[int],
mask: torch.Tensor,
):
mask = torch.full_like(logits, -math.inf)
for i in range(logits.shape[0]):
fsm = self.fsms[i]
if fsm_grammar_states[i] == -1 or fsm is None:
continue
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
mask[i, allowed_tokens] = 0
logits += mask
return logits
def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
return [
GrammarLogitProcessor._advance(
next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
)
for i in range(len(next_token_ids))
]
def advance_at_index(self, next_token_id, fsm_grammar_state, index):
return GrammarLogitProcessor._advance(
next_token_id, fsm_grammar_state, self.fsms[index]
)
def filter(self, indices):
return GrammarLogitProcessor.filter(self, indices)

View File

@ -1,11 +1,13 @@
import re import re
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import math
import torch import torch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
from text_generation_server.utils.logits_process import ( from text_generation_server.utils.logits_process import (
FrequencyPenaltyLogitsProcessor, FrequencyPenaltyLogitsProcessor,
GrammarLogitProcessor,
HeterogeneousProcessorWrapper, HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor, HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousFrequencyPenaltyLogitsProcessor, HeterogeneousFrequencyPenaltyLogitsProcessor,
@ -13,6 +15,7 @@ from text_generation_server.utils.logits_process import (
HeterogeneousTopKLogitsWarper, HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper, HeterogeneousTopPLogitsWarper,
HeterogeneousTypicalLogitsWarper, HeterogeneousTypicalLogitsWarper,
HeterogeneousGrammarLogitProcessor,
static_warper, static_warper,
) )
from text_generation_server.utils.watermark import WatermarkLogitsProcessor from text_generation_server.utils.watermark import WatermarkLogitsProcessor
@ -22,16 +25,20 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcess
class NextTokenChooser: class NextTokenChooser:
def __init__( def __init__(
self, self,
watermark=False, watermark: bool = False,
temperature=1.0, temperature: float = 1.0,
repetition_penalty=1.0, repetition_penalty: float = 1.0,
frequency_penalty=0.0, frequency_penalty: float = 0.0,
top_k=None, top_k: Optional[int] = None,
top_p=None, top_p: Optional[float] = None,
typical_p=None, typical_p: Optional[float] = None,
do_sample=False, do_sample: bool = False,
seed=0, seed: int = 0,
device="cpu", device: str = "cpu",
tokenizer: Optional[PreTrainedTokenizerBase] = None,
grammar: str = "",
grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
fsm_grammar_state: int = 0,
): ):
self.watermark_processor = ( self.watermark_processor = (
WatermarkLogitsProcessor(device=device) if watermark else None WatermarkLogitsProcessor(device=device) if watermark else None
@ -46,6 +53,12 @@ class NextTokenChooser:
if frequency_penalty and frequency_penalty != 0.0 if frequency_penalty and frequency_penalty != 0.0
else None else None
) )
self.grammar_processor = (
GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
if grammar != ""
else None
)
self.tokenizer = tokenizer
has_warpers = ( has_warpers = (
(temperature is not None and temperature != 1.0) (temperature is not None and temperature != 1.0)
@ -61,7 +74,10 @@ class NextTokenChooser:
self.static_warper = None self.static_warper = None
sampling = do_sample or has_warpers sampling = do_sample or has_warpers
self.choice = Sampling(seed, device) if sampling else Greedy() self.choice = Sampling(seed, device) if sampling else Greedy()
self.fsm_grammar_state = fsm_grammar_state
self.grammar = grammar
def __call__(self, input_ids, scores): def __call__(self, input_ids, scores):
if self.watermark_processor is not None: if self.watermark_processor is not None:
@ -70,6 +86,8 @@ class NextTokenChooser:
scores = self.repetition_processor(input_ids, scores) scores = self.repetition_processor(input_ids, scores)
if self.frequency_processor is not None: if self.frequency_processor is not None:
scores = self.frequency_processor(input_ids, scores) scores = self.frequency_processor(input_ids, scores)
if self.grammar_processor is not None:
scores = self.grammar_processor(scores, self.fsm_grammar_state)
if self.static_warper is None: if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1) next_logprob = torch.log_softmax(scores, -1)
@ -80,11 +98,19 @@ class NextTokenChooser:
return next_id, next_logprob return next_id, next_logprob
def advance_grammar(self, next_id: int):
if self.grammar_processor is not None:
self.fsm_grammar_state = self.grammar_processor.advance(
next_id, self.fsm_grammar_state
)
return self
@classmethod @classmethod
def from_pb( def from_pb(
cls, cls,
pb: generate_pb2.NextTokenChooserParameters, pb: generate_pb2.NextTokenChooserParameters,
device: torch.device, device: torch.device,
tokenizer: PreTrainedTokenizerBase,
) -> "NextTokenChooser": ) -> "NextTokenChooser":
return NextTokenChooser( return NextTokenChooser(
watermark=pb.watermark, watermark=pb.watermark,
@ -97,6 +123,9 @@ class NextTokenChooser:
do_sample=pb.do_sample, do_sample=pb.do_sample,
seed=pb.seed, seed=pb.seed,
device=device, device=device,
tokenizer=tokenizer,
grammar=pb.grammar,
grammar_type=pb.grammar_type,
) )
@ -201,6 +230,10 @@ class HeterogeneousNextTokenChooser:
typical_p: List[float], typical_p: List[float],
do_sample: List[bool], do_sample: List[bool],
seeds: List[int], seeds: List[int],
tokenizer: PreTrainedTokenizerBase,
grammars: List[str],
grammar_types: List[int],
fsm_grammar_states=List[int],
): ):
warpers = [] warpers = []
@ -232,6 +265,14 @@ class HeterogeneousNextTokenChooser:
else None else None
) )
self.grammar_processor = (
HeterogeneousGrammarLogitProcessor(
tokenizer, device, grammars, grammar_types
)
if any([grammar != "" for grammar in grammars])
else None
)
if any([x != 1.0 for x in temperature]): if any([x != 1.0 for x in temperature]):
do_sample = [ do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample) sample or x != 1.0 for x, sample in zip(temperature, do_sample)
@ -263,6 +304,10 @@ class HeterogeneousNextTokenChooser:
self.do_sample = do_sample self.do_sample = do_sample
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
self.tokenizer = tokenizer
self.fsm_grammar_states = fsm_grammar_states
self.grammars = grammars
self.grammar_types = grammar_types
def __call__( def __call__(
self, self,
@ -283,6 +328,8 @@ class HeterogeneousNextTokenChooser:
scores = scores.view(B, S, -1) scores = scores.view(B, S, -1)
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long) next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
for j in range(S): for j in range(S):
_scores = scores[:, j] _scores = scores[:, j]
if self.watermark_processor is not None: if self.watermark_processor is not None:
@ -291,10 +338,10 @@ class HeterogeneousNextTokenChooser:
_scores = self.repetition_processor(input_ids, _scores) _scores = self.repetition_processor(input_ids, _scores)
if self.frequency_processor is not None: if self.frequency_processor is not None:
_scores = self.frequency_processor(input_ids, _scores) _scores = self.frequency_processor(input_ids, _scores)
for warper in self.warpers: for warper in self.warpers:
_scores = warper(input_ids, _scores) _scores = warper(input_ids, _scores)
if self.grammar_processor is not None:
_scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask)
_next_ids = self.choice(_scores) _next_ids = self.choice(_scores)
scores[:, j] = _scores scores[:, j] = _scores
next_ids[:, j] = _next_ids next_ids[:, j] = _next_ids
@ -352,6 +399,21 @@ class HeterogeneousNextTokenChooser:
return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
def advance_grammar(self, next_ids: List[int]):
if self.grammar_processor is not None:
other_new_states = self.grammar_processor.advance_batch(
next_ids, self.fsm_grammar_states, self.grammars
)
self.fsm_grammar_states = other_new_states
return self
def advance_grammar_single(self, grammar_state_index: int, next_id: int):
if self.grammar_processor is not None:
self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index(
next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index
)
return self
def filter(self, indices): def filter(self, indices):
if self.watermark_processor is not None: if self.watermark_processor is not None:
self.watermark_processor = self.watermark_processor.filter(indices) self.watermark_processor = self.watermark_processor.filter(indices)
@ -362,6 +424,9 @@ class HeterogeneousNextTokenChooser:
if self.frequency_processor is not None: if self.frequency_processor is not None:
self.frequency_processor = self.frequency_processor.filter(indices) self.frequency_processor = self.frequency_processor.filter(indices)
if self.grammar_processor is not None:
self.grammar_processor = self.grammar_processor.filter(indices)
filtered_warpers = [] filtered_warpers = []
for warper in self.warpers: for warper in self.warpers:
filtered_warper = warper.filter(indices) filtered_warper = warper.filter(indices)
@ -372,6 +437,18 @@ class HeterogeneousNextTokenChooser:
self.seeds = [self.seeds[i] for i in indices] self.seeds = [self.seeds[i] for i in indices]
self.do_sample = [self.do_sample[i] for i in indices] self.do_sample = [self.do_sample[i] for i in indices]
new_grammars = []
new_fsm_grammar_states = []
new_grammar_types = []
for i in indices:
new_grammars.append(self.grammars[i])
new_fsm_grammar_states.append(self.fsm_grammar_states[i])
new_grammar_types.append(self.grammar_types[i])
self.grammars = new_grammars
self.fsm_grammar_states = new_fsm_grammar_states
self.grammar_types = new_grammar_types
if any(self.do_sample): if any(self.do_sample):
self.choice.filter(indices) self.choice.filter(indices)
else: else:
@ -385,6 +462,7 @@ class HeterogeneousNextTokenChooser:
pb: List[generate_pb2.NextTokenChooserParameters], pb: List[generate_pb2.NextTokenChooserParameters],
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
tokenizer: PreTrainedTokenizerBase,
) -> "HeterogeneousNextTokenChooser": ) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser( return HeterogeneousNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb], watermark=[pb_.watermark for pb_ in pb],
@ -398,6 +476,10 @@ class HeterogeneousNextTokenChooser:
seeds=[pb_.seed for pb_ in pb], seeds=[pb_.seed for pb_ in pb],
device=device, device=device,
dtype=dtype, dtype=dtype,
tokenizer=tokenizer,
grammars=[pb_.grammar for pb_ in pb],
grammar_types=[pb_.grammar_type for pb_ in pb],
fsm_grammar_states=[0] * len(pb),
) )