Outlines guided generation (#1539)
This WIP PR starts to add grammar support via outlines, currently this PR supports very simple regex grammars and does not optimize for precompiling or caching grammar fsm's. todo: - [X] add simple outlines guidance to `NextTokenChooser` - [X] update protos for grammar - [X] update generation params API - [X] constrain simple grammar - [ ] support parsing more complex grammar into fsm - [ ] support all outline support grammar types - [ ] explore optimizations to avoid recompiling grammars guided request ```bash curl -s 'http://localhost:3000/generate' \ --header 'Content-Type: application/json' \ --data-raw '{ "inputs": "make an email for david: \n", "parameters": { "max_new_tokens": 6, "grammar": "[\\w-]+@([\\w-]+\\.)+[\\w-]+" } }' | jq ``` response ```json { "generated_text": "david@example.com" } ``` unguided request ```bash curl -s 'http://localhost:3000/generate' \ --header 'Content-Type: application/json' \ --data '{ "inputs": "make an email for david: \n", "parameters": { "max_new_tokens": 6 } }' | jq ``` response ```json { "generated_text": " email = 'david" } ```
This commit is contained in:
parent
4c2848b24b
commit
cef0553d59
|
@ -8,7 +8,7 @@ use crate::app::App;
|
||||||
use crate::event::Event;
|
use crate::event::Event;
|
||||||
use crossterm::ExecutableCommand;
|
use crossterm::ExecutableCommand;
|
||||||
use std::io;
|
use std::io;
|
||||||
use text_generation_client::{NextTokenChooserParameters, ShardedClient};
|
use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use tokio::sync::{broadcast, mpsc};
|
use tokio::sync::{broadcast, mpsc};
|
||||||
use tui::backend::CrosstermBackend;
|
use tui::backend::CrosstermBackend;
|
||||||
|
@ -45,6 +45,8 @@ pub async fn run(
|
||||||
repetition_penalty: repetition_penalty.unwrap_or(1.0),
|
repetition_penalty: repetition_penalty.unwrap_or(1.0),
|
||||||
frequency_penalty: frequency_penalty.unwrap_or(0.0),
|
frequency_penalty: frequency_penalty.unwrap_or(0.0),
|
||||||
watermark,
|
watermark,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: GrammarType::None as i32,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Initialize terminal properties
|
// Initialize terminal properties
|
||||||
|
|
|
@ -10,6 +10,7 @@ from text_generation.types import (
|
||||||
Response,
|
Response,
|
||||||
Request,
|
Request,
|
||||||
Parameters,
|
Parameters,
|
||||||
|
Grammar,
|
||||||
)
|
)
|
||||||
from text_generation.errors import parse_error
|
from text_generation.errors import parse_error
|
||||||
|
|
||||||
|
@ -76,6 +77,7 @@ class Client:
|
||||||
watermark: bool = False,
|
watermark: bool = False,
|
||||||
decoder_input_details: bool = False,
|
decoder_input_details: bool = False,
|
||||||
top_n_tokens: Optional[int] = None,
|
top_n_tokens: Optional[int] = None,
|
||||||
|
grammar: Optional[Grammar] = None,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Given a prompt, generate the following text
|
Given a prompt, generate the following text
|
||||||
|
@ -138,6 +140,7 @@ class Client:
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
decoder_input_details=decoder_input_details,
|
decoder_input_details=decoder_input_details,
|
||||||
top_n_tokens=top_n_tokens,
|
top_n_tokens=top_n_tokens,
|
||||||
|
grammar=grammar,
|
||||||
)
|
)
|
||||||
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||||
|
|
||||||
|
@ -169,6 +172,7 @@ class Client:
|
||||||
typical_p: Optional[float] = None,
|
typical_p: Optional[float] = None,
|
||||||
watermark: bool = False,
|
watermark: bool = False,
|
||||||
top_n_tokens: Optional[int] = None,
|
top_n_tokens: Optional[int] = None,
|
||||||
|
grammar: Optional[Grammar] = None,
|
||||||
) -> Iterator[StreamResponse]:
|
) -> Iterator[StreamResponse]:
|
||||||
"""
|
"""
|
||||||
Given a prompt, generate the following stream of tokens
|
Given a prompt, generate the following stream of tokens
|
||||||
|
@ -227,6 +231,7 @@ class Client:
|
||||||
typical_p=typical_p,
|
typical_p=typical_p,
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
top_n_tokens=top_n_tokens,
|
top_n_tokens=top_n_tokens,
|
||||||
|
grammar=grammar,
|
||||||
)
|
)
|
||||||
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||||
|
|
||||||
|
@ -326,6 +331,7 @@ class AsyncClient:
|
||||||
watermark: bool = False,
|
watermark: bool = False,
|
||||||
decoder_input_details: bool = False,
|
decoder_input_details: bool = False,
|
||||||
top_n_tokens: Optional[int] = None,
|
top_n_tokens: Optional[int] = None,
|
||||||
|
grammar: Optional[Grammar] = None,
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Given a prompt, generate the following text asynchronously
|
Given a prompt, generate the following text asynchronously
|
||||||
|
@ -370,6 +376,7 @@ class AsyncClient:
|
||||||
Returns:
|
Returns:
|
||||||
Response: generated response
|
Response: generated response
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Validate parameters
|
# Validate parameters
|
||||||
parameters = Parameters(
|
parameters = Parameters(
|
||||||
best_of=best_of,
|
best_of=best_of,
|
||||||
|
@ -388,6 +395,7 @@ class AsyncClient:
|
||||||
typical_p=typical_p,
|
typical_p=typical_p,
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
top_n_tokens=top_n_tokens,
|
top_n_tokens=top_n_tokens,
|
||||||
|
grammar=grammar,
|
||||||
)
|
)
|
||||||
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
request = Request(inputs=prompt, stream=False, parameters=parameters)
|
||||||
|
|
||||||
|
@ -417,6 +425,7 @@ class AsyncClient:
|
||||||
typical_p: Optional[float] = None,
|
typical_p: Optional[float] = None,
|
||||||
watermark: bool = False,
|
watermark: bool = False,
|
||||||
top_n_tokens: Optional[int] = None,
|
top_n_tokens: Optional[int] = None,
|
||||||
|
grammar: Optional[Grammar] = None,
|
||||||
) -> AsyncIterator[StreamResponse]:
|
) -> AsyncIterator[StreamResponse]:
|
||||||
"""
|
"""
|
||||||
Given a prompt, generate the following stream of tokens asynchronously
|
Given a prompt, generate the following stream of tokens asynchronously
|
||||||
|
@ -475,6 +484,7 @@ class AsyncClient:
|
||||||
typical_p=typical_p,
|
typical_p=typical_p,
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
top_n_tokens=top_n_tokens,
|
top_n_tokens=top_n_tokens,
|
||||||
|
grammar=grammar,
|
||||||
)
|
)
|
||||||
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
request = Request(inputs=prompt, stream=True, parameters=parameters)
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,24 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel, validator
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Union
|
||||||
|
|
||||||
from text_generation.errors import ValidationError
|
from text_generation.errors import ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
# enum for grammar type
|
||||||
|
class GrammarType(str, Enum):
|
||||||
|
Json = "json"
|
||||||
|
Regex = "regex"
|
||||||
|
|
||||||
|
|
||||||
|
# Grammar type and value
|
||||||
|
class Grammar(BaseModel):
|
||||||
|
# Grammar type
|
||||||
|
type: GrammarType
|
||||||
|
# Grammar value
|
||||||
|
value: Union[str, dict]
|
||||||
|
|
||||||
|
|
||||||
class Parameters(BaseModel):
|
class Parameters(BaseModel):
|
||||||
# Activate logits sampling
|
# Activate logits sampling
|
||||||
do_sample: bool = False
|
do_sample: bool = False
|
||||||
|
@ -41,6 +55,8 @@ class Parameters(BaseModel):
|
||||||
decoder_input_details: bool = False
|
decoder_input_details: bool = False
|
||||||
# Return the N most likely tokens at each step
|
# Return the N most likely tokens at each step
|
||||||
top_n_tokens: Optional[int] = None
|
top_n_tokens: Optional[int] = None
|
||||||
|
# grammar to use for generation
|
||||||
|
grammar: Optional[Grammar] = None
|
||||||
|
|
||||||
@validator("best_of")
|
@validator("best_of")
|
||||||
def valid_best_of(cls, field_value, values):
|
def valid_best_of(cls, field_value, values):
|
||||||
|
@ -109,6 +125,14 @@ class Parameters(BaseModel):
|
||||||
raise ValidationError("`top_n_tokens` must be strictly positive")
|
raise ValidationError("`top_n_tokens` must be strictly positive")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@validator("grammar")
|
||||||
|
def valid_grammar(cls, v):
|
||||||
|
if v is not None:
|
||||||
|
if v.type == GrammarType.Regex and not v.value:
|
||||||
|
raise ValidationError("`value` cannot be empty for `regex` grammar")
|
||||||
|
if v.type == GrammarType.Json and not v.value:
|
||||||
|
raise ValidationError("`value` cannot be empty for `json` grammar")
|
||||||
|
return v
|
||||||
|
|
||||||
class Request(BaseModel):
|
class Request(BaseModel):
|
||||||
# Prompt
|
# Prompt
|
||||||
|
@ -157,7 +181,7 @@ class Token(BaseModel):
|
||||||
# Token text
|
# Token text
|
||||||
text: str
|
text: str
|
||||||
# Logprob
|
# Logprob
|
||||||
logprob: float
|
logprob: Optional[float] = None
|
||||||
# Is the token a special token
|
# Is the token a special token
|
||||||
# Can be used to ignore tokens when concatenating
|
# Can be used to ignore tokens when concatenating
|
||||||
special: bool
|
special: bool
|
||||||
|
|
|
@ -378,6 +378,14 @@ Options:
|
||||||
|
|
||||||
[env: TOKENIZER_CONFIG_PATH=]
|
[env: TOKENIZER_CONFIG_PATH=]
|
||||||
|
|
||||||
|
```
|
||||||
|
## DISABLE_GRAMMAR_SUPPORT
|
||||||
|
```shell
|
||||||
|
--disable-grammar-support
|
||||||
|
Disable outlines grammar constrained generation. This is a feature that allows you to generate text that follows a specific grammar
|
||||||
|
|
||||||
|
[env: DISABLE_GRAMMAR_SUPPORT=]
|
||||||
|
|
||||||
```
|
```
|
||||||
## ENV
|
## ENV
|
||||||
```shell
|
```shell
|
||||||
|
|
|
@ -16,7 +16,14 @@ from syrupy.extensions.json import JSONSnapshotExtension
|
||||||
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
|
||||||
|
|
||||||
from text_generation import AsyncClient
|
from text_generation import AsyncClient
|
||||||
from text_generation.types import Response, Details, InputToken, Token, BestOfSequence
|
from text_generation.types import (
|
||||||
|
Response,
|
||||||
|
Details,
|
||||||
|
InputToken,
|
||||||
|
Token,
|
||||||
|
BestOfSequence,
|
||||||
|
Grammar,
|
||||||
|
)
|
||||||
|
|
||||||
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||||
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
||||||
|
@ -224,6 +231,7 @@ def launcher(event_loop):
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
use_flash_attention: bool = True,
|
use_flash_attention: bool = True,
|
||||||
|
disable_grammar_support: bool = False,
|
||||||
dtype: Optional[str] = None,
|
dtype: Optional[str] = None,
|
||||||
):
|
):
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
|
@ -247,6 +255,8 @@ def launcher(event_loop):
|
||||||
|
|
||||||
env = os.environ
|
env = os.environ
|
||||||
|
|
||||||
|
if disable_grammar_support:
|
||||||
|
args.append("--disable-grammar-support")
|
||||||
if num_shard is not None:
|
if num_shard is not None:
|
||||||
args.extend(["--num-shard", str(num_shard)])
|
args.extend(["--num-shard", str(num_shard)])
|
||||||
if quantize is not None:
|
if quantize is not None:
|
||||||
|
@ -287,12 +297,15 @@ def launcher(event_loop):
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
use_flash_attention: bool = True,
|
use_flash_attention: bool = True,
|
||||||
|
disable_grammar_support: bool = False,
|
||||||
dtype: Optional[str] = None,
|
dtype: Optional[str] = None,
|
||||||
):
|
):
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
|
|
||||||
args = ["--model-id", model_id, "--env"]
|
args = ["--model-id", model_id, "--env"]
|
||||||
|
|
||||||
|
if disable_grammar_support:
|
||||||
|
args.append("--disable-grammar-support")
|
||||||
if num_shard is not None:
|
if num_shard is not None:
|
||||||
args.extend(["--num-shard", str(num_shard)])
|
args.extend(["--num-shard", str(num_shard)])
|
||||||
if quantize is not None:
|
if quantize is not None:
|
||||||
|
@ -370,11 +383,22 @@ def launcher(event_loop):
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def generate_load():
|
def generate_load():
|
||||||
async def generate_load_inner(
|
async def generate_load_inner(
|
||||||
client: AsyncClient, prompt: str, max_new_tokens: int, n: int
|
client: AsyncClient,
|
||||||
|
prompt: str,
|
||||||
|
max_new_tokens: int,
|
||||||
|
n: int,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
grammar: Optional[Grammar] = None,
|
||||||
|
stop_sequences: Optional[List[str]] = None,
|
||||||
) -> List[Response]:
|
) -> List[Response]:
|
||||||
futures = [
|
futures = [
|
||||||
client.generate(
|
client.generate(
|
||||||
prompt, max_new_tokens=max_new_tokens, decoder_input_details=True
|
prompt,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=seed,
|
||||||
|
grammar=grammar,
|
||||||
|
stop_sequences=stop_sequences,
|
||||||
)
|
)
|
||||||
for _ in range(n)
|
for _ in range(n)
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4321,
|
||||||
|
"logprob": -13.90625,
|
||||||
|
"text": "Test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2009,
|
||||||
|
"logprob": -12.328125,
|
||||||
|
"text": "request"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -2.0566406,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13,
|
||||||
|
"logprob": -1.5253906,
|
||||||
|
"special": false,
|
||||||
|
"text": "\n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29902,
|
||||||
|
"logprob": -2.7578125,
|
||||||
|
"special": false,
|
||||||
|
"text": "I"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4966,
|
||||||
|
"logprob": -1.9033203,
|
||||||
|
"special": false,
|
||||||
|
"text": " hope"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 445,
|
||||||
|
"logprob": -0.5019531,
|
||||||
|
"special": false,
|
||||||
|
"text": " this"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6911,
|
||||||
|
"logprob": -0.21264648,
|
||||||
|
"special": false,
|
||||||
|
"text": " helps"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29991,
|
||||||
|
"logprob": -0.5991211,
|
||||||
|
"special": false,
|
||||||
|
"text": "!"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2803,
|
||||||
|
"logprob": -0.37475586,
|
||||||
|
"special": false,
|
||||||
|
"text": " Let"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 592,
|
||||||
|
"logprob": -0.018463135,
|
||||||
|
"special": false,
|
||||||
|
"text": " me"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1073,
|
||||||
|
"logprob": -0.0008597374,
|
||||||
|
"special": false,
|
||||||
|
"text": " know"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "\n\nI hope this helps! Let me know"
|
||||||
|
}
|
|
@ -0,0 +1,274 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "eos_token",
|
||||||
|
"generated_tokens": 30,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 5235,
|
||||||
|
"logprob": -10.0625,
|
||||||
|
"text": "info"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -3.2324219,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13260,
|
||||||
|
"logprob": -10.625,
|
||||||
|
"text": "dav"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 333,
|
||||||
|
"logprob": -0.08276367,
|
||||||
|
"text": "id"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 8753,
|
||||||
|
"logprob": -7.5273438,
|
||||||
|
"text": "hol"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 17559,
|
||||||
|
"logprob": -3.8476562,
|
||||||
|
"text": "tz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 763,
|
||||||
|
"logprob": -10.140625,
|
||||||
|
"text": "like"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 10697,
|
||||||
|
"logprob": -10.1953125,
|
||||||
|
"text": "trees"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 322,
|
||||||
|
"logprob": -2.5742188,
|
||||||
|
"text": "and"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 756,
|
||||||
|
"logprob": -7.4882812,
|
||||||
|
"text": "has"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1023,
|
||||||
|
"logprob": -5.0507812,
|
||||||
|
"text": "two"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 274,
|
||||||
|
"logprob": -5.3164062,
|
||||||
|
"text": "c"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1446,
|
||||||
|
"logprob": -0.6694336,
|
||||||
|
"text": "ats"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -0.9995117,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29871,
|
||||||
|
"logprob": -4.2421875,
|
||||||
|
"text": ""
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 6377,
|
||||||
|
"logprob": -0.14916992,
|
||||||
|
"special": false,
|
||||||
|
"text": "{\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29888,
|
||||||
|
"logprob": -0.13598633,
|
||||||
|
"special": false,
|
||||||
|
"text": "f"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12935,
|
||||||
|
"logprob": -0.017669678,
|
||||||
|
"special": false,
|
||||||
|
"text": "irs"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29873,
|
||||||
|
"logprob": -0.00085639954,
|
||||||
|
"special": false,
|
||||||
|
"text": "t"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1170,
|
||||||
|
"logprob": -0.0054016113,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4710,
|
||||||
|
"logprob": -0.13549805,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 19504,
|
||||||
|
"logprob": -0.8852539,
|
||||||
|
"special": false,
|
||||||
|
"text": "David"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3284,
|
||||||
|
"logprob": -0.16394043,
|
||||||
|
"special": false,
|
||||||
|
"text": "\",\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4230,
|
||||||
|
"logprob": -0.020492554,
|
||||||
|
"special": false,
|
||||||
|
"text": "last"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1170,
|
||||||
|
"logprob": -0.0013818741,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4710,
|
||||||
|
"logprob": -0.0067749023,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29950,
|
||||||
|
"logprob": -0.11578369,
|
||||||
|
"special": false,
|
||||||
|
"text": "H"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 14339,
|
||||||
|
"logprob": -0.004131317,
|
||||||
|
"special": false,
|
||||||
|
"text": "olt"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29920,
|
||||||
|
"logprob": -0.0033359528,
|
||||||
|
"special": false,
|
||||||
|
"text": "z"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3284,
|
||||||
|
"logprob": -0.20471191,
|
||||||
|
"special": false,
|
||||||
|
"text": "\",\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29882,
|
||||||
|
"logprob": -0.0069274902,
|
||||||
|
"special": false,
|
||||||
|
"text": "h"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 20838,
|
||||||
|
"logprob": -0.19580078,
|
||||||
|
"special": false,
|
||||||
|
"text": "obb"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29891,
|
||||||
|
"logprob": -2.2649765e-06,
|
||||||
|
"special": false,
|
||||||
|
"text": "y"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4710,
|
||||||
|
"logprob": -0.32080078,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29911,
|
||||||
|
"logprob": -2.1035156,
|
||||||
|
"special": false,
|
||||||
|
"text": "T"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 11003,
|
||||||
|
"logprob": -0.020767212,
|
||||||
|
"special": false,
|
||||||
|
"text": "rees"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3284,
|
||||||
|
"logprob": -0.6010742,
|
||||||
|
"special": false,
|
||||||
|
"text": "\",\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29876,
|
||||||
|
"logprob": -0.57666016,
|
||||||
|
"special": false,
|
||||||
|
"text": "n"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 398,
|
||||||
|
"logprob": -0.0061073303,
|
||||||
|
"special": false,
|
||||||
|
"text": "um"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29907,
|
||||||
|
"logprob": -0.45703125,
|
||||||
|
"special": false,
|
||||||
|
"text": "C"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1446,
|
||||||
|
"logprob": -0.0002872944,
|
||||||
|
"special": false,
|
||||||
|
"text": "ats"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1115,
|
||||||
|
"logprob": -0.0021018982,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29906,
|
||||||
|
"logprob": -0.08996582,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29913,
|
||||||
|
"logprob": -0.021697998,
|
||||||
|
"special": false,
|
||||||
|
"text": "}"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": true,
|
||||||
|
"text": "</s>"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}"
|
||||||
|
}
|
|
@ -0,0 +1,478 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1024,
|
||||||
|
"logprob": -10.578125,
|
||||||
|
"text": "name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -3.03125,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13260,
|
||||||
|
"logprob": -9.171875,
|
||||||
|
"text": "dav"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 333,
|
||||||
|
"logprob": -0.04244995,
|
||||||
|
"text": "id"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -2.4863281,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4876,
|
||||||
|
"logprob": -10.7890625,
|
||||||
|
"text": "email"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -0.32714844,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 259,
|
||||||
|
"logprob": -9.4921875,
|
||||||
|
"text": " "
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -0.7685547,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29906,
|
||||||
|
"logprob": -0.2376709,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29941,
|
||||||
|
"logprob": -0.01008606,
|
||||||
|
"special": false,
|
||||||
|
"text": "3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29946,
|
||||||
|
"logprob": -0.64160156,
|
||||||
|
"special": false,
|
||||||
|
"text": "4"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29945,
|
||||||
|
"logprob": -0.5,
|
||||||
|
"special": false,
|
||||||
|
"text": "5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29953,
|
||||||
|
"logprob": -0.46557617,
|
||||||
|
"special": false,
|
||||||
|
"text": "6"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29992,
|
||||||
|
"logprob": -0.5341797,
|
||||||
|
"special": false,
|
||||||
|
"text": "@"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21980,
|
||||||
|
"logprob": -0.5361328,
|
||||||
|
"special": false,
|
||||||
|
"text": "gmail"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -0.00088739395,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 510,
|
||||||
|
"logprob": -0.0022907257,
|
||||||
|
"special": false,
|
||||||
|
"text": "com"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "123456@gmail.com"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1024,
|
||||||
|
"logprob": -10.578125,
|
||||||
|
"text": "name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -3.0332031,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13260,
|
||||||
|
"logprob": -9.171875,
|
||||||
|
"text": "dav"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 333,
|
||||||
|
"logprob": -0.04257202,
|
||||||
|
"text": "id"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -2.4785156,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4876,
|
||||||
|
"logprob": -10.7890625,
|
||||||
|
"text": "email"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -0.32495117,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 259,
|
||||||
|
"logprob": -9.4921875,
|
||||||
|
"text": " "
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -0.7709961,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29906,
|
||||||
|
"logprob": -0.23840332,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29941,
|
||||||
|
"logprob": -0.00995636,
|
||||||
|
"special": false,
|
||||||
|
"text": "3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29946,
|
||||||
|
"logprob": -0.64208984,
|
||||||
|
"special": false,
|
||||||
|
"text": "4"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29945,
|
||||||
|
"logprob": -0.4970703,
|
||||||
|
"special": false,
|
||||||
|
"text": "5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29953,
|
||||||
|
"logprob": -0.46533203,
|
||||||
|
"special": false,
|
||||||
|
"text": "6"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29992,
|
||||||
|
"logprob": -0.5336914,
|
||||||
|
"special": false,
|
||||||
|
"text": "@"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21980,
|
||||||
|
"logprob": -0.5361328,
|
||||||
|
"special": false,
|
||||||
|
"text": "gmail"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -0.00088739395,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 510,
|
||||||
|
"logprob": -0.0022735596,
|
||||||
|
"special": false,
|
||||||
|
"text": "com"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "123456@gmail.com"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1024,
|
||||||
|
"logprob": -10.578125,
|
||||||
|
"text": "name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -3.0332031,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13260,
|
||||||
|
"logprob": -9.171875,
|
||||||
|
"text": "dav"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 333,
|
||||||
|
"logprob": -0.04257202,
|
||||||
|
"text": "id"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -2.4785156,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4876,
|
||||||
|
"logprob": -10.7890625,
|
||||||
|
"text": "email"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -0.32495117,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 259,
|
||||||
|
"logprob": -9.4921875,
|
||||||
|
"text": " "
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -0.7709961,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29906,
|
||||||
|
"logprob": -0.23840332,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29941,
|
||||||
|
"logprob": -0.00995636,
|
||||||
|
"special": false,
|
||||||
|
"text": "3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29946,
|
||||||
|
"logprob": -0.64208984,
|
||||||
|
"special": false,
|
||||||
|
"text": "4"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29945,
|
||||||
|
"logprob": -0.4970703,
|
||||||
|
"special": false,
|
||||||
|
"text": "5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29953,
|
||||||
|
"logprob": -0.46533203,
|
||||||
|
"special": false,
|
||||||
|
"text": "6"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29992,
|
||||||
|
"logprob": -0.5336914,
|
||||||
|
"special": false,
|
||||||
|
"text": "@"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21980,
|
||||||
|
"logprob": -0.5361328,
|
||||||
|
"special": false,
|
||||||
|
"text": "gmail"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -0.00088739395,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 510,
|
||||||
|
"logprob": -0.0022735596,
|
||||||
|
"special": false,
|
||||||
|
"text": "com"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "123456@gmail.com"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1024,
|
||||||
|
"logprob": -10.578125,
|
||||||
|
"text": "name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -3.0332031,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 13260,
|
||||||
|
"logprob": -9.171875,
|
||||||
|
"text": "dav"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 333,
|
||||||
|
"logprob": -0.04257202,
|
||||||
|
"text": "id"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -2.4785156,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4876,
|
||||||
|
"logprob": -10.7890625,
|
||||||
|
"text": "email"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29901,
|
||||||
|
"logprob": -0.32495117,
|
||||||
|
"text": ":"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 259,
|
||||||
|
"logprob": -9.4921875,
|
||||||
|
"text": " "
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -0.7709961,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29906,
|
||||||
|
"logprob": -0.23840332,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29941,
|
||||||
|
"logprob": -0.00995636,
|
||||||
|
"special": false,
|
||||||
|
"text": "3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29946,
|
||||||
|
"logprob": -0.64208984,
|
||||||
|
"special": false,
|
||||||
|
"text": "4"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29945,
|
||||||
|
"logprob": -0.4970703,
|
||||||
|
"special": false,
|
||||||
|
"text": "5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29953,
|
||||||
|
"logprob": -0.46533203,
|
||||||
|
"special": false,
|
||||||
|
"text": "6"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29992,
|
||||||
|
"logprob": -0.5336914,
|
||||||
|
"special": false,
|
||||||
|
"text": "@"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21980,
|
||||||
|
"logprob": -0.5361328,
|
||||||
|
"special": false,
|
||||||
|
"text": "gmail"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -0.00088739395,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 510,
|
||||||
|
"logprob": -0.0022735596,
|
||||||
|
"special": false,
|
||||||
|
"text": "com"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "123456@gmail.com"
|
||||||
|
}
|
||||||
|
]
|
|
@ -0,0 +1,109 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"logprob": null,
|
||||||
|
"text": "<s>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 806,
|
||||||
|
"logprob": -11.890625,
|
||||||
|
"text": "Wh"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1446,
|
||||||
|
"logprob": -3.6699219,
|
||||||
|
"text": "ats"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2921,
|
||||||
|
"logprob": -7.8203125,
|
||||||
|
"text": "Go"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 468,
|
||||||
|
"logprob": -8.0703125,
|
||||||
|
"text": "og"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 793,
|
||||||
|
"logprob": -2.1875,
|
||||||
|
"text": "les"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 16332,
|
||||||
|
"logprob": -9.7109375,
|
||||||
|
"text": "DNS"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 29946,
|
||||||
|
"logprob": -1.4765625,
|
||||||
|
"special": false,
|
||||||
|
"text": "4"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29906,
|
||||||
|
"logprob": -0.9199219,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": 0.0,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -1.1367188,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -1.4648438,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -0.40722656,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -0.17419434,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -0.20251465,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29900,
|
||||||
|
"logprob": -1.5527344,
|
||||||
|
"special": false,
|
||||||
|
"text": "0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -1.3710938,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "42.1.1.101"
|
||||||
|
}
|
|
@ -0,0 +1,73 @@
|
||||||
|
{
|
||||||
|
"details": {
|
||||||
|
"best_of_sequences": null,
|
||||||
|
"finish_reason": "length",
|
||||||
|
"generated_tokens": 10,
|
||||||
|
"prefill": [],
|
||||||
|
"seed": null,
|
||||||
|
"tokens": [
|
||||||
|
{
|
||||||
|
"id": 29896,
|
||||||
|
"logprob": -0.7685547,
|
||||||
|
"special": false,
|
||||||
|
"text": "1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29906,
|
||||||
|
"logprob": -0.33666992,
|
||||||
|
"special": false,
|
||||||
|
"text": "2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29941,
|
||||||
|
"logprob": -0.009979248,
|
||||||
|
"special": false,
|
||||||
|
"text": "3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29946,
|
||||||
|
"logprob": -0.64208984,
|
||||||
|
"special": false,
|
||||||
|
"text": "4"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29945,
|
||||||
|
"logprob": -0.4970703,
|
||||||
|
"special": false,
|
||||||
|
"text": "5"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29953,
|
||||||
|
"logprob": -0.46533203,
|
||||||
|
"special": false,
|
||||||
|
"text": "6"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29992,
|
||||||
|
"logprob": -0.5336914,
|
||||||
|
"special": false,
|
||||||
|
"text": "@"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 21980,
|
||||||
|
"logprob": -0.53759766,
|
||||||
|
"special": false,
|
||||||
|
"text": "gmail"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29889,
|
||||||
|
"logprob": -0.0008878708,
|
||||||
|
"special": false,
|
||||||
|
"text": "."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 510,
|
||||||
|
"logprob": -0.002275467,
|
||||||
|
"special": false,
|
||||||
|
"text": "com"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"top_tokens": null
|
||||||
|
},
|
||||||
|
"generated_text": "123456@gmail.com"
|
||||||
|
}
|
|
@ -0,0 +1,151 @@
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
|
||||||
|
from text_generation.types import GrammarType
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def flash_llama_grammar_handle(launcher):
|
||||||
|
with launcher(
|
||||||
|
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
|
||||||
|
) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def flash_llama_grammar(flash_llama_grammar_handle):
|
||||||
|
await flash_llama_grammar_handle.health(300)
|
||||||
|
return flash_llama_grammar_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
|
||||||
|
response = await flash_llama_grammar.generate(
|
||||||
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
|
||||||
|
response = await flash_llama_grammar.generate(
|
||||||
|
"Whats Googles DNS",
|
||||||
|
max_new_tokens=10,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
grammar={
|
||||||
|
"type": GrammarType.Regex, # "regex"
|
||||||
|
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 10
|
||||||
|
assert response.generated_text == "42.1.1.101"
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
||||||
|
response = await flash_llama_grammar.generate(
|
||||||
|
"info: david holtz like trees and has two cats. ",
|
||||||
|
max_new_tokens=100,
|
||||||
|
decoder_input_details=True,
|
||||||
|
seed=0,
|
||||||
|
grammar={
|
||||||
|
"type": GrammarType.Json, # "json"
|
||||||
|
"value": json.dumps(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"$id": "https://example.com/person.schema.json",
|
||||||
|
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||||
|
"title": "Person",
|
||||||
|
"properties": {
|
||||||
|
"firstName": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The person'''s first name.",
|
||||||
|
},
|
||||||
|
"lastName": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The person'''s last name.",
|
||||||
|
},
|
||||||
|
"hobby": {
|
||||||
|
"description": "The person'''s hobby.",
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
"numCats": {
|
||||||
|
"description": "The number of cats the person has.",
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["firstName", "lastName", "hobby", "numCats"],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.details.generated_tokens == 30
|
||||||
|
assert (
|
||||||
|
response.generated_text
|
||||||
|
== '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}'
|
||||||
|
)
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama_grammar_load(
|
||||||
|
flash_llama_grammar, generate_load, response_snapshot
|
||||||
|
):
|
||||||
|
responses = await generate_load(
|
||||||
|
flash_llama_grammar,
|
||||||
|
"name: david. email: ",
|
||||||
|
max_new_tokens=10,
|
||||||
|
n=4,
|
||||||
|
stop_sequences=[".com"],
|
||||||
|
seed=0,
|
||||||
|
grammar={
|
||||||
|
"type": GrammarType.Regex, # "regex"
|
||||||
|
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(responses) == 4
|
||||||
|
|
||||||
|
expected = "123456@gmail.com"
|
||||||
|
|
||||||
|
for response in responses:
|
||||||
|
assert response.generated_text == expected
|
||||||
|
|
||||||
|
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
# this is the same as the above test, but only fires off a single request
|
||||||
|
# this is only to ensure that the parallel and single inference produce the same result
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.private
|
||||||
|
async def test_flash_llama_grammar_single_load_instance(
|
||||||
|
flash_llama_grammar, generate_load, response_snapshot
|
||||||
|
):
|
||||||
|
response = await flash_llama_grammar.generate(
|
||||||
|
"name: david. email: ",
|
||||||
|
max_new_tokens=10,
|
||||||
|
stop_sequences=[".com"],
|
||||||
|
seed=0,
|
||||||
|
grammar={
|
||||||
|
"type": GrammarType.Regex, # "regex"
|
||||||
|
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# assert response.details.generated_tokens == 30
|
||||||
|
assert response.generated_text == "123456@gmail.com"
|
||||||
|
|
||||||
|
assert response == response_snapshot
|
|
@ -382,6 +382,11 @@ struct Args {
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
tokenizer_config_path: Option<String>,
|
tokenizer_config_path: Option<String>,
|
||||||
|
|
||||||
|
/// Disable outlines grammar constrained generation.
|
||||||
|
/// This is a feature that allows you to generate text that follows a specific grammar.
|
||||||
|
#[clap(long, env)]
|
||||||
|
disable_grammar_support: bool,
|
||||||
|
|
||||||
/// Display a lot of information about your runtime environment
|
/// Display a lot of information about your runtime environment
|
||||||
#[clap(long, short, action)]
|
#[clap(long, short, action)]
|
||||||
env: bool,
|
env: bool,
|
||||||
|
@ -1051,6 +1056,11 @@ fn spawn_webserver(
|
||||||
args.model_id,
|
args.model_id,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// Grammar support
|
||||||
|
if args.disable_grammar_support {
|
||||||
|
router_args.push("--disable-grammar-support".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
// Tokenizer config path
|
// Tokenizer config path
|
||||||
if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
|
if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
|
||||||
router_args.push("--tokenizer-config-path".to_string());
|
router_args.push("--tokenizer-config-path".to_string());
|
||||||
|
|
|
@ -51,6 +51,12 @@ message ClearCacheRequest {
|
||||||
/// Empty response
|
/// Empty response
|
||||||
message ClearCacheResponse {}
|
message ClearCacheResponse {}
|
||||||
|
|
||||||
|
enum GrammarType {
|
||||||
|
GRAMMAR_TYPE_NONE = 0;
|
||||||
|
GRAMMAR_TYPE_JSON = 1;
|
||||||
|
GRAMMAR_TYPE_REGEX = 2;
|
||||||
|
}
|
||||||
|
|
||||||
message NextTokenChooserParameters {
|
message NextTokenChooserParameters {
|
||||||
/// exponential scaling output probability distribution
|
/// exponential scaling output probability distribution
|
||||||
float temperature = 1;
|
float temperature = 1;
|
||||||
|
@ -70,6 +76,10 @@ message NextTokenChooserParameters {
|
||||||
float frequency_penalty = 9;
|
float frequency_penalty = 9;
|
||||||
/// token watermarking using "A Watermark for Large Language Models"
|
/// token watermarking using "A Watermark for Large Language Models"
|
||||||
bool watermark = 8;
|
bool watermark = 8;
|
||||||
|
/// grammar (applied if not empty)
|
||||||
|
string grammar = 10;
|
||||||
|
/// grammar type
|
||||||
|
GrammarType grammar_type = 11;
|
||||||
}
|
}
|
||||||
|
|
||||||
message StoppingCriteriaParameters {
|
message StoppingCriteriaParameters {
|
||||||
|
|
|
@ -128,6 +128,8 @@ impl Client {
|
||||||
repetition_penalty: 1.2,
|
repetition_penalty: 1.2,
|
||||||
frequency_penalty: 0.1,
|
frequency_penalty: 0.1,
|
||||||
watermark: true,
|
watermark: true,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: GrammarType::None as i32,
|
||||||
}),
|
}),
|
||||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
max_new_tokens: max_total_tokens - truncate,
|
max_new_tokens: max_total_tokens - truncate,
|
||||||
|
|
|
@ -9,8 +9,8 @@ pub use client::Client;
|
||||||
pub use pb::generate::v2::HealthResponse;
|
pub use pb::generate::v2::HealthResponse;
|
||||||
pub use pb::generate::v2::InfoResponse as ShardInfo;
|
pub use pb::generate::v2::InfoResponse as ShardInfo;
|
||||||
pub use pb::generate::v2::{
|
pub use pb::generate::v2::{
|
||||||
Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
|
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||||
Request, StoppingCriteriaParameters, Tokens,
|
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
|
||||||
};
|
};
|
||||||
pub use sharded_client::ShardedClient;
|
pub use sharded_client::ShardedClient;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use text_generation_client::GrammarType as ProtoGrammarType;
|
||||||
use text_generation_client::{
|
use text_generation_client::{
|
||||||
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
|
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
|
@ -45,6 +46,8 @@ impl Health {
|
||||||
repetition_penalty: 1.0,
|
repetition_penalty: 1.0,
|
||||||
frequency_penalty: 0.0,
|
frequency_penalty: 0.0,
|
||||||
watermark: false,
|
watermark: false,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: ProtoGrammarType::None as i32,
|
||||||
}),
|
}),
|
||||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
max_new_tokens: 1,
|
max_new_tokens: 1,
|
||||||
|
|
|
@ -45,6 +45,43 @@ impl HubTokenizerConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mod json_object_or_string_to_string {
|
||||||
|
use serde::{Deserialize, Deserializer};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
// A custom deserializer that treats both strings and objects as strings.
|
||||||
|
// This provides flexibility with input formats for the 'grammar' field.
|
||||||
|
pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let value = Value::deserialize(deserializer)?;
|
||||||
|
|
||||||
|
match value {
|
||||||
|
Value::String(s) => Ok(s),
|
||||||
|
// Safely handle serialization and return an error if it fails
|
||||||
|
Value::Object(o) => {
|
||||||
|
serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
|
||||||
|
}
|
||||||
|
_ => Err(serde::de::Error::custom(
|
||||||
|
"expected string or object for grammar",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
|
#[serde(tag = "type", content = "value")]
|
||||||
|
pub(crate) enum GrammarType {
|
||||||
|
#[serde(
|
||||||
|
rename = "json",
|
||||||
|
deserialize_with = "json_object_or_string_to_string::deserialize"
|
||||||
|
)]
|
||||||
|
Json(String),
|
||||||
|
#[serde(rename = "regex")]
|
||||||
|
Regex(String),
|
||||||
|
}
|
||||||
|
|
||||||
mod token_serde {
|
mod token_serde {
|
||||||
use super::*;
|
use super::*;
|
||||||
use serde::de;
|
use serde::de;
|
||||||
|
@ -201,6 +238,8 @@ pub(crate) struct GenerateParameters {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
|
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
|
||||||
pub top_n_tokens: Option<u32>,
|
pub top_n_tokens: Option<u32>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub grammar: Option<GrammarType>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_max_new_tokens() -> Option<u32> {
|
fn default_max_new_tokens() -> Option<u32> {
|
||||||
|
@ -226,6 +265,7 @@ fn default_parameters() -> GenerateParameters {
|
||||||
decoder_input_details: false,
|
decoder_input_details: false,
|
||||||
seed: None,
|
seed: None,
|
||||||
top_n_tokens: None,
|
top_n_tokens: None,
|
||||||
|
grammar: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -75,6 +75,8 @@ struct Args {
|
||||||
ngrok_edge: Option<String>,
|
ngrok_edge: Option<String>,
|
||||||
#[clap(long, env, default_value_t = false)]
|
#[clap(long, env, default_value_t = false)]
|
||||||
messages_api_enabled: bool,
|
messages_api_enabled: bool,
|
||||||
|
#[clap(long, env, default_value_t = false)]
|
||||||
|
disable_grammar_support: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
@ -108,6 +110,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
ngrok_authtoken,
|
ngrok_authtoken,
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
messages_api_enabled,
|
messages_api_enabled,
|
||||||
|
disable_grammar_support,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Launch Tokio runtime
|
// Launch Tokio runtime
|
||||||
|
@ -359,6 +362,7 @@ async fn main() -> Result<(), RouterError> {
|
||||||
ngrok_edge,
|
ngrok_edge,
|
||||||
tokenizer_config,
|
tokenizer_config,
|
||||||
messages_api_enabled,
|
messages_api_enabled,
|
||||||
|
disable_grammar_support,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -343,7 +343,9 @@ enum QueueCommand {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
use text_generation_client::{
|
||||||
|
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
use tracing::info_span;
|
use tracing::info_span;
|
||||||
|
|
||||||
fn default_entry() -> (
|
fn default_entry() -> (
|
||||||
|
@ -354,7 +356,7 @@ mod tests {
|
||||||
|
|
||||||
let entry = Entry {
|
let entry = Entry {
|
||||||
request: ValidGenerateRequest {
|
request: ValidGenerateRequest {
|
||||||
inputs: "".to_string(),
|
inputs: String::new(),
|
||||||
input_length: 0,
|
input_length: 0,
|
||||||
truncate: 0,
|
truncate: 0,
|
||||||
decoder_input_details: false,
|
decoder_input_details: false,
|
||||||
|
@ -368,6 +370,8 @@ mod tests {
|
||||||
repetition_penalty: 0.0,
|
repetition_penalty: 0.0,
|
||||||
frequency_penalty: 0.0,
|
frequency_penalty: 0.0,
|
||||||
watermark: false,
|
watermark: false,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: ProtoGrammarType::None as i32,
|
||||||
},
|
},
|
||||||
stopping_parameters: StoppingCriteriaParameters {
|
stopping_parameters: StoppingCriteriaParameters {
|
||||||
ignore_eos_token: false,
|
ignore_eos_token: false,
|
||||||
|
|
|
@ -614,6 +614,7 @@ async fn chat_completions(
|
||||||
decoder_input_details: !stream,
|
decoder_input_details: !stream,
|
||||||
seed,
|
seed,
|
||||||
top_n_tokens: None,
|
top_n_tokens: None,
|
||||||
|
grammar: None,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -779,6 +780,7 @@ pub async fn run(
|
||||||
ngrok_edge: Option<String>,
|
ngrok_edge: Option<String>,
|
||||||
tokenizer_config: HubTokenizerConfig,
|
tokenizer_config: HubTokenizerConfig,
|
||||||
messages_api_enabled: bool,
|
messages_api_enabled: bool,
|
||||||
|
grammar_support: bool,
|
||||||
) -> Result<(), axum::BoxError> {
|
) -> Result<(), axum::BoxError> {
|
||||||
// OpenAPI documentation
|
// OpenAPI documentation
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
|
@ -840,6 +842,7 @@ pub async fn run(
|
||||||
max_top_n_tokens,
|
max_top_n_tokens,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
|
grammar_support,
|
||||||
);
|
);
|
||||||
let generation_health = Arc::new(AtomicBool::new(false));
|
let generation_health = Arc::new(AtomicBool::new(false));
|
||||||
let health_ext = Health::new(client.clone(), generation_health.clone());
|
let health_ext = Health::new(client.clone(), generation_health.clone());
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||||
use crate::{GenerateParameters, GenerateRequest};
|
use crate::{GenerateParameters, GenerateRequest, GrammarType};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
|
use text_generation_client::{
|
||||||
|
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokenizers::tokenizer::Tokenizer;
|
use tokenizers::tokenizer::Tokenizer;
|
||||||
use tokenizers::TruncationDirection;
|
use tokenizers::TruncationDirection;
|
||||||
|
@ -19,6 +21,7 @@ pub struct Validation {
|
||||||
max_top_n_tokens: u32,
|
max_top_n_tokens: u32,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
|
disable_grammar_support: bool,
|
||||||
/// Channel to communicate with the background tokenization task
|
/// Channel to communicate with the background tokenization task
|
||||||
sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
|
sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
|
||||||
}
|
}
|
||||||
|
@ -32,6 +35,7 @@ impl Validation {
|
||||||
max_top_n_tokens: u32,
|
max_top_n_tokens: u32,
|
||||||
max_input_length: usize,
|
max_input_length: usize,
|
||||||
max_total_tokens: usize,
|
max_total_tokens: usize,
|
||||||
|
disable_grammar_support: bool,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// If we have a fast tokenizer
|
// If we have a fast tokenizer
|
||||||
let sender = if let Some(tokenizer) = tokenizer {
|
let sender = if let Some(tokenizer) = tokenizer {
|
||||||
|
@ -66,6 +70,7 @@ impl Validation {
|
||||||
max_top_n_tokens,
|
max_top_n_tokens,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
|
disable_grammar_support,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,6 +187,7 @@ impl Validation {
|
||||||
watermark,
|
watermark,
|
||||||
decoder_input_details,
|
decoder_input_details,
|
||||||
top_n_tokens,
|
top_n_tokens,
|
||||||
|
grammar,
|
||||||
..
|
..
|
||||||
} = request.parameters;
|
} = request.parameters;
|
||||||
|
|
||||||
|
@ -292,6 +298,28 @@ impl Validation {
|
||||||
.validate_input(request.inputs, truncate, max_new_tokens)
|
.validate_input(request.inputs, truncate, max_new_tokens)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
// TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
|
||||||
|
// NOTE: this is currently difficult because we need the tokenizer in Python to build
|
||||||
|
// the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
|
||||||
|
// may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
|
||||||
|
// compiler and use that to build the FSM here.
|
||||||
|
|
||||||
|
// Validate grammar and unpack the grammar and type for the proto message
|
||||||
|
let (grammar, grammar_type) = match grammar {
|
||||||
|
Some(grammar) => {
|
||||||
|
// Ensure that grammar is not set if it's not supported
|
||||||
|
if self.disable_grammar_support {
|
||||||
|
return Err(ValidationError::Grammar);
|
||||||
|
}
|
||||||
|
match grammar {
|
||||||
|
// currently both are handled the same way since compilation is done in Python
|
||||||
|
GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
|
||||||
|
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => (String::new(), ProtoGrammarType::None.into()),
|
||||||
|
};
|
||||||
|
|
||||||
let parameters = NextTokenChooserParameters {
|
let parameters = NextTokenChooserParameters {
|
||||||
temperature,
|
temperature,
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
|
@ -302,6 +330,8 @@ impl Validation {
|
||||||
do_sample,
|
do_sample,
|
||||||
seed,
|
seed,
|
||||||
watermark,
|
watermark,
|
||||||
|
grammar,
|
||||||
|
grammar_type,
|
||||||
};
|
};
|
||||||
let stopping_parameters = StoppingCriteriaParameters {
|
let stopping_parameters = StoppingCriteriaParameters {
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
|
@ -453,6 +483,8 @@ pub enum ValidationError {
|
||||||
StopSequence(usize, usize),
|
StopSequence(usize, usize),
|
||||||
#[error("tokenizer error {0}")]
|
#[error("tokenizer error {0}")]
|
||||||
Tokenizer(String),
|
Tokenizer(String),
|
||||||
|
#[error("grammar is not supported")]
|
||||||
|
Grammar,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -470,6 +502,7 @@ mod tests {
|
||||||
let max_input_length = 5;
|
let max_input_length = 5;
|
||||||
let max_total_tokens = 6;
|
let max_total_tokens = 6;
|
||||||
let workers = 1;
|
let workers = 1;
|
||||||
|
let disable_grammar_support = true;
|
||||||
let validation = Validation::new(
|
let validation = Validation::new(
|
||||||
workers,
|
workers,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
@ -478,6 +511,7 @@ mod tests {
|
||||||
max_top_n_tokens,
|
max_top_n_tokens,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
|
disable_grammar_support,
|
||||||
);
|
);
|
||||||
|
|
||||||
let max_new_tokens = 10;
|
let max_new_tokens = 10;
|
||||||
|
@ -498,6 +532,7 @@ mod tests {
|
||||||
let max_top_n_tokens = 4;
|
let max_top_n_tokens = 4;
|
||||||
let max_input_length = 5;
|
let max_input_length = 5;
|
||||||
let max_total_tokens = 6;
|
let max_total_tokens = 6;
|
||||||
|
let disable_grammar_support = true;
|
||||||
let workers = 1;
|
let workers = 1;
|
||||||
let validation = Validation::new(
|
let validation = Validation::new(
|
||||||
workers,
|
workers,
|
||||||
|
@ -507,6 +542,7 @@ mod tests {
|
||||||
max_top_n_tokens,
|
max_top_n_tokens,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
|
disable_grammar_support,
|
||||||
);
|
);
|
||||||
|
|
||||||
let max_new_tokens = 10;
|
let max_new_tokens = 10;
|
||||||
|
@ -528,6 +564,7 @@ mod tests {
|
||||||
let max_input_length = 5;
|
let max_input_length = 5;
|
||||||
let max_total_tokens = 6;
|
let max_total_tokens = 6;
|
||||||
let workers = 1;
|
let workers = 1;
|
||||||
|
let disable_grammar_support = true;
|
||||||
let validation = Validation::new(
|
let validation = Validation::new(
|
||||||
workers,
|
workers,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
@ -536,6 +573,7 @@ mod tests {
|
||||||
max_top_n_tokens,
|
max_top_n_tokens,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
|
disable_grammar_support,
|
||||||
);
|
);
|
||||||
match validation
|
match validation
|
||||||
.validate(GenerateRequest {
|
.validate(GenerateRequest {
|
||||||
|
@ -562,6 +600,7 @@ mod tests {
|
||||||
let max_input_length = 5;
|
let max_input_length = 5;
|
||||||
let max_total_tokens = 106;
|
let max_total_tokens = 106;
|
||||||
let workers = 1;
|
let workers = 1;
|
||||||
|
let disable_grammar_support = true;
|
||||||
let validation = Validation::new(
|
let validation = Validation::new(
|
||||||
workers,
|
workers,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
@ -570,6 +609,7 @@ mod tests {
|
||||||
max_top_n_tokens,
|
max_top_n_tokens,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
|
disable_grammar_support,
|
||||||
);
|
);
|
||||||
match validation
|
match validation
|
||||||
.validate(GenerateRequest {
|
.validate(GenerateRequest {
|
||||||
|
@ -625,6 +665,7 @@ mod tests {
|
||||||
let max_input_length = 5;
|
let max_input_length = 5;
|
||||||
let max_total_tokens = 106;
|
let max_total_tokens = 106;
|
||||||
let workers = 1;
|
let workers = 1;
|
||||||
|
let disable_grammar_support = true;
|
||||||
let validation = Validation::new(
|
let validation = Validation::new(
|
||||||
workers,
|
workers,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
@ -633,6 +674,7 @@ mod tests {
|
||||||
max_top_n_tokens,
|
max_top_n_tokens,
|
||||||
max_input_length,
|
max_input_length,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
|
disable_grammar_support,
|
||||||
);
|
);
|
||||||
match validation
|
match validation
|
||||||
.validate(GenerateRequest {
|
.validate(GenerateRequest {
|
||||||
|
|
|
@ -34,6 +34,7 @@ peft = { version = "^0.8.2", optional = true }
|
||||||
torch = { version = "^2.1.1", optional = true }
|
torch = { version = "^2.1.1", optional = true }
|
||||||
scipy = "^1.11.1"
|
scipy = "^1.11.1"
|
||||||
pillow = "^10.0.0"
|
pillow = "^10.0.0"
|
||||||
|
outlines="^0.0.27"
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
torch = ["torch"]
|
torch = ["torch"]
|
||||||
|
|
|
@ -87,7 +87,9 @@ class CausalLMBatch(Batch):
|
||||||
for i, r in enumerate(pb.requests):
|
for i, r in enumerate(pb.requests):
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
inputs.append(r.inputs)
|
inputs.append(r.inputs)
|
||||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
next_token_choosers.append(
|
||||||
|
NextTokenChooser.from_pb(r.parameters, device, tokenizer)
|
||||||
|
)
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
r.stopping_parameters, tokenizer
|
r.stopping_parameters, tokenizer
|
||||||
)
|
)
|
||||||
|
@ -413,14 +415,14 @@ class CausalLMBatch(Batch):
|
||||||
# We slice the keys to remove the padding from previous batches
|
# We slice the keys to remove the padding from previous batches
|
||||||
past_seq_len = batch.max_input_length - 1
|
past_seq_len = batch.max_input_length - 1
|
||||||
if batch.keys_head_dim_last:
|
if batch.keys_head_dim_last:
|
||||||
padded_past_keys[
|
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
|
||||||
start_index:end_index, :, -past_seq_len:, :
|
past_keys[:, :, -past_seq_len:, :]
|
||||||
] = past_keys[:, :, -past_seq_len:, :]
|
)
|
||||||
else:
|
else:
|
||||||
# BLOOM case
|
# BLOOM case
|
||||||
padded_past_keys[
|
padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (
|
||||||
start_index:end_index, :, :, -past_seq_len:
|
past_keys[:, :, :, -past_seq_len:]
|
||||||
] = past_keys[:, :, :, -past_seq_len:]
|
)
|
||||||
del past_keys
|
del past_keys
|
||||||
|
|
||||||
start_index = end_index
|
start_index = end_index
|
||||||
|
@ -438,9 +440,9 @@ class CausalLMBatch(Batch):
|
||||||
end_index = start_index + len(batch)
|
end_index = start_index + len(batch)
|
||||||
# We slice the past values to remove the padding from previous batches
|
# We slice the past values to remove the padding from previous batches
|
||||||
past_seq_len = batch.max_input_length - 1
|
past_seq_len = batch.max_input_length - 1
|
||||||
padded_past_values[
|
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
|
||||||
start_index:end_index, :, -past_seq_len:, :
|
past_values[:, :, -past_seq_len:, :]
|
||||||
] = past_values[:, :, -past_seq_len:, :]
|
)
|
||||||
del past_values
|
del past_values
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
|
@ -504,9 +506,11 @@ class CausalLM(Model):
|
||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
device_map="auto"
|
device_map=(
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
"auto"
|
||||||
else None,
|
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||||
|
else None
|
||||||
|
),
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
load_in_8bit=quantize == "bitsandbytes",
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -696,7 +700,7 @@ class CausalLM(Model):
|
||||||
|
|
||||||
if top_n_tokens > 0:
|
if top_n_tokens > 0:
|
||||||
all_top_tokens = []
|
all_top_tokens = []
|
||||||
for (top_token_ids, top_token_logprobs) in zip(
|
for top_token_ids, top_token_logprobs in zip(
|
||||||
top_token_ids, top_token_logprobs
|
top_token_ids, top_token_logprobs
|
||||||
):
|
):
|
||||||
toptoken_texts = self.tokenizer.batch_decode(
|
toptoken_texts = self.tokenizer.batch_decode(
|
||||||
|
@ -735,6 +739,9 @@ class CausalLM(Model):
|
||||||
generations.append(generation)
|
generations.append(generation)
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
|
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
|
||||||
|
next_token_id_squeezed.item()
|
||||||
|
)
|
||||||
batch.input_ids[i, 0] = next_token_id
|
batch.input_ids[i, 0] = next_token_id
|
||||||
batch.all_input_ids[i] = all_input_ids
|
batch.all_input_ids[i] = all_input_ids
|
||||||
batch.input_lengths[i] = new_input_length
|
batch.input_lengths[i] = new_input_length
|
||||||
|
|
|
@ -237,7 +237,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
)
|
)
|
||||||
|
|
||||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
next_token_chooser_parameters, dtype, device
|
next_token_chooser_parameters, dtype, device, tokenizer
|
||||||
)
|
)
|
||||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||||
|
|
||||||
|
@ -593,6 +593,7 @@ class FlashCausalLMBatch(Batch):
|
||||||
next_token_chooser_parameters,
|
next_token_chooser_parameters,
|
||||||
dtype=batches[0].next_token_chooser.dtype,
|
dtype=batches[0].next_token_chooser.dtype,
|
||||||
device=batches[0].next_token_chooser.device,
|
device=batches[0].next_token_chooser.device,
|
||||||
|
tokenizer=batches[0].next_token_chooser.tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
speculative_ids = (
|
speculative_ids = (
|
||||||
|
@ -869,7 +870,11 @@ class FlashCausalLM(Model):
|
||||||
# Try to find an associated cuda graph
|
# Try to find an associated cuda graph
|
||||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||||
|
|
||||||
if cu_seqlen_prefill is not None or cuda_graph is None or batch.speculative_ids is not None:
|
if (
|
||||||
|
cu_seqlen_prefill is not None
|
||||||
|
or cuda_graph is None
|
||||||
|
or batch.speculative_ids is not None
|
||||||
|
):
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
@ -1013,9 +1018,9 @@ class FlashCausalLM(Model):
|
||||||
# Copy batch.input_ids to prefill_token_indices
|
# Copy batch.input_ids to prefill_token_indices
|
||||||
if prefill_logprobs:
|
if prefill_logprobs:
|
||||||
if len(batch) > 1:
|
if len(batch) > 1:
|
||||||
prefill_tokens_indices[
|
prefill_tokens_indices[out_start_index : out_end_index - 1] = (
|
||||||
out_start_index : out_end_index - 1
|
batch.input_ids[start_index + 1 : start_index + out_length]
|
||||||
] = batch.input_ids[start_index + 1 : start_index + out_length]
|
)
|
||||||
else:
|
else:
|
||||||
# Set prefill_tokens_indices to the correct slice
|
# Set prefill_tokens_indices to the correct slice
|
||||||
prefill_tokens_indices = batch.input_ids[
|
prefill_tokens_indices = batch.input_ids[
|
||||||
|
@ -1028,6 +1033,7 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
cumulative_length += input_length
|
cumulative_length += input_length
|
||||||
|
|
||||||
|
# Update values
|
||||||
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
|
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
|
||||||
batch.speculative_ids = speculative_ids
|
batch.speculative_ids = speculative_ids
|
||||||
batch.position_ids = next_position_ids + accepted_ids
|
batch.position_ids = next_position_ids + accepted_ids
|
||||||
|
@ -1166,7 +1172,7 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
if top_n_tokens > 0:
|
if top_n_tokens > 0:
|
||||||
all_top_tokens = []
|
all_top_tokens = []
|
||||||
for (top_token_ids, top_token_logprobs) in zip(
|
for top_token_ids, top_token_logprobs in zip(
|
||||||
top_token_ids, top_token_logprobs
|
top_token_ids, top_token_logprobs
|
||||||
):
|
):
|
||||||
toptoken_texts = self.tokenizer.batch_decode(
|
toptoken_texts = self.tokenizer.batch_decode(
|
||||||
|
@ -1204,6 +1210,12 @@ class FlashCausalLM(Model):
|
||||||
|
|
||||||
generations.append(generation)
|
generations.append(generation)
|
||||||
|
|
||||||
|
# accept each new token for this specific request since we may
|
||||||
|
# have more than one new token per request with speculative decoding
|
||||||
|
for next_token_id in _next_token_ids:
|
||||||
|
batch.next_token_chooser = batch.next_token_chooser.advance_grammar_single(i, next_token_id)
|
||||||
|
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
batch.input_lengths[i] = input_length + n_accepted_ids
|
batch.input_lengths[i] = input_length + n_accepted_ids
|
||||||
if batch.input_lengths[i] > batch.max_seqlen:
|
if batch.input_lengths[i] > batch.max_seqlen:
|
||||||
|
|
|
@ -192,7 +192,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
|
||||||
)
|
)
|
||||||
|
|
||||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||||
next_token_chooser_parameters, dtype, device
|
next_token_chooser_parameters, dtype, device, tokenizer
|
||||||
)
|
)
|
||||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||||
|
|
||||||
|
|
|
@ -92,7 +92,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
# Add escape_custom_split_sequence to the CausalLMBatch logic
|
||||||
inputs.append(escape_custom_split_sequence(r.inputs))
|
inputs.append(escape_custom_split_sequence(r.inputs))
|
||||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
r.stopping_parameters, tokenizer
|
r.stopping_parameters, tokenizer
|
||||||
)
|
)
|
||||||
|
|
|
@ -114,7 +114,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
for i, r in enumerate(pb.requests):
|
for i, r in enumerate(pb.requests):
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
inputs.append(r.inputs)
|
inputs.append(r.inputs)
|
||||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
r.stopping_parameters, tokenizer
|
r.stopping_parameters, tokenizer
|
||||||
)
|
)
|
||||||
|
@ -815,6 +815,9 @@ class IdeficsCausalLM(Model):
|
||||||
generations.append(generation)
|
generations.append(generation)
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
|
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
|
||||||
|
next_token_id_squeezed.item()
|
||||||
|
)
|
||||||
batch.input_ids[i, 0] = next_token_id
|
batch.input_ids[i, 0] = next_token_id
|
||||||
batch.all_input_ids[i] = all_input_ids
|
batch.all_input_ids[i] = all_input_ids
|
||||||
batch.input_lengths[i] = new_input_length
|
batch.input_lengths[i] = new_input_length
|
||||||
|
|
|
@ -124,7 +124,7 @@ class MambaBatch(Batch):
|
||||||
for i, r in enumerate(pb.requests):
|
for i, r in enumerate(pb.requests):
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
inputs.append(r.inputs)
|
inputs.append(r.inputs)
|
||||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
r.stopping_parameters, tokenizer
|
r.stopping_parameters, tokenizer
|
||||||
)
|
)
|
||||||
|
@ -694,6 +694,9 @@ class Mamba(Model):
|
||||||
generations.append(generation)
|
generations.append(generation)
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
|
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
|
||||||
|
next_token_id_squeezed.item()
|
||||||
|
)
|
||||||
batch.input_ids[i, 0] = next_token_id
|
batch.input_ids[i, 0] = next_token_id
|
||||||
batch.all_input_ids[i] = all_input_ids
|
batch.all_input_ids[i] = all_input_ids
|
||||||
batch.input_lengths[i] = new_input_length
|
batch.input_lengths[i] = new_input_length
|
||||||
|
|
|
@ -96,7 +96,7 @@ class Seq2SeqLMBatch(Batch):
|
||||||
inputs.append(r.inputs)
|
inputs.append(r.inputs)
|
||||||
requests_idx_mapping[r.id] = i
|
requests_idx_mapping[r.id] = i
|
||||||
decoder_input_lengths.append(1)
|
decoder_input_lengths.append(1)
|
||||||
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
|
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device, tokenizer))
|
||||||
stopping_criteria = StoppingCriteria.from_pb(
|
stopping_criteria = StoppingCriteria.from_pb(
|
||||||
r.stopping_parameters, tokenizer
|
r.stopping_parameters, tokenizer
|
||||||
)
|
)
|
||||||
|
@ -789,6 +789,9 @@ class Seq2SeqLM(Model):
|
||||||
generations.append(generation)
|
generations.append(generation)
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
|
batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
|
||||||
|
next_token_id_squeezed.item()
|
||||||
|
)
|
||||||
batch.decoder_input_ids[i] = next_token_id
|
batch.decoder_input_ids[i] = next_token_id
|
||||||
batch.all_decoder_input_ids[i] = all_decoder_input_ids
|
batch.all_decoder_input_ids[i] = all_decoder_input_ids
|
||||||
batch.input_lengths[i] = input_length
|
batch.input_lengths[i] = input_length
|
||||||
|
|
|
@ -1,8 +1,17 @@
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import json
|
||||||
|
from loguru import logger
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Optional, List, Dict, Union
|
from typing import Optional, List, Dict, Union
|
||||||
|
from text_generation_server.pb.generate_pb2 import GrammarType
|
||||||
|
|
||||||
|
from outlines.fsm.fsm import RegexFSM
|
||||||
|
from outlines.fsm.json_schema import build_regex_from_object
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import List, Optional, DefaultDict
|
||||||
|
import time
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
LogitsWarper,
|
LogitsWarper,
|
||||||
|
@ -135,9 +144,7 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
score = torch.gather(scores, 1, input_ids)
|
score = torch.gather(scores, 1, input_ids)
|
||||||
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
|
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
|
||||||
score = -torch.where(
|
score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||||
score < 0, score * self.penalty, score / self.penalty
|
|
||||||
)
|
|
||||||
|
|
||||||
return scores.scatter_add_(1, input_ids, score)
|
return scores.scatter_add_(1, input_ids, score)
|
||||||
|
|
||||||
|
@ -464,3 +471,132 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
|
||||||
self.processors = new_processors
|
self.processors = new_processors
|
||||||
return self
|
return self
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class GrammarLogitProcessor(LogitsProcessor):
|
||||||
|
fsm_state: DefaultDict[int, int]
|
||||||
|
fsm: RegexFSM
|
||||||
|
|
||||||
|
def __init__(self, tokenizer, device, grammar, grammar_type):
|
||||||
|
self.device = device
|
||||||
|
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
|
||||||
|
self.fsm = GrammarLogitProcessor._cached_compile_fsm(
|
||||||
|
grammar_type, grammar, self.tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
fsm_grammar_state: int,
|
||||||
|
):
|
||||||
|
if fsm_grammar_state == -1 or self.fsm is None:
|
||||||
|
return logits
|
||||||
|
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
|
||||||
|
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
|
||||||
|
mask[allowed_tokens] = 0
|
||||||
|
biased_scores = logits + mask
|
||||||
|
return biased_scores
|
||||||
|
|
||||||
|
def advance(self, next_token_id, fsm_grammar_state):
|
||||||
|
return GrammarLogitProcessor._advance(
|
||||||
|
next_token_id, fsm_grammar_state, self.fsm
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _advance(next_token_id, fsm_grammar_state, fsm):
|
||||||
|
if fsm_grammar_state == -1:
|
||||||
|
return fsm_grammar_state
|
||||||
|
return fsm.next_state(fsm_grammar_state, next_token_id)
|
||||||
|
|
||||||
|
# TODO: move grammar compilation into the router
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache(maxsize=32, typed=True)
|
||||||
|
def _cached_compile_fsm(grammar_type, schema, tokenizer):
|
||||||
|
start_time = time.time()
|
||||||
|
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
|
||||||
|
schema = build_regex_from_object(schema)
|
||||||
|
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
|
||||||
|
pass # schema is already a regex just here for clarity
|
||||||
|
fsm = RegexFSM(schema, tokenizer)
|
||||||
|
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
|
||||||
|
return fsm
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@lru_cache(maxsize=32, typed=True)
|
||||||
|
def _cached_adapt_tokenizer(tokenizer):
|
||||||
|
"""Adapt tokenizer to work with the FSM.
|
||||||
|
|
||||||
|
The API of Outlines tokenizers is slightly different to that of
|
||||||
|
`transformers`. In addition we need to handle the missing spaces to
|
||||||
|
Llama's tokenizer to be able to compile FSMs for this model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
tokenizer.vocabulary = tokenizer.get_vocab()
|
||||||
|
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
||||||
|
|
||||||
|
def convert_token_to_string(token: str) -> str:
|
||||||
|
from transformers.file_utils import SPIECE_UNDERLINE
|
||||||
|
|
||||||
|
string = tokenizer.convert_tokens_to_string([token])
|
||||||
|
|
||||||
|
# A hack to handle missing spaces to HF's Llama tokenizers
|
||||||
|
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
|
||||||
|
return " " + string
|
||||||
|
|
||||||
|
return string
|
||||||
|
|
||||||
|
tokenizer.convert_token_to_string = convert_token_to_string
|
||||||
|
logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
def filter(self, indices):
|
||||||
|
new_fsms = []
|
||||||
|
for i in indices:
|
||||||
|
new_fsms.append(self.fsms[i])
|
||||||
|
self.fsms = new_fsms
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||||
|
def __init__(self, tokenizer, device, grammars, grammar_type):
|
||||||
|
self.device = device
|
||||||
|
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
|
||||||
|
self.fsms = []
|
||||||
|
for i in range(len(grammars)):
|
||||||
|
fsm = GrammarLogitProcessor._cached_compile_fsm(
|
||||||
|
grammar_type[i], grammars[i], self.tokenizer
|
||||||
|
)
|
||||||
|
self.fsms.append(fsm)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
fsm_grammar_states: List[int],
|
||||||
|
mask: torch.Tensor,
|
||||||
|
):
|
||||||
|
mask = torch.full_like(logits, -math.inf)
|
||||||
|
for i in range(logits.shape[0]):
|
||||||
|
fsm = self.fsms[i]
|
||||||
|
if fsm_grammar_states[i] == -1 or fsm is None:
|
||||||
|
continue
|
||||||
|
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
|
||||||
|
mask[i, allowed_tokens] = 0
|
||||||
|
logits += mask
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
|
||||||
|
return [
|
||||||
|
GrammarLogitProcessor._advance(
|
||||||
|
next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
|
||||||
|
)
|
||||||
|
for i in range(len(next_token_ids))
|
||||||
|
]
|
||||||
|
|
||||||
|
def advance_at_index(self, next_token_id, fsm_grammar_state, index):
|
||||||
|
return GrammarLogitProcessor._advance(
|
||||||
|
next_token_id, fsm_grammar_state, self.fsms[index]
|
||||||
|
)
|
||||||
|
|
||||||
|
def filter(self, indices):
|
||||||
|
return GrammarLogitProcessor.filter(self, indices)
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
from text_generation_server.pb import generate_pb2
|
from text_generation_server.pb import generate_pb2
|
||||||
from text_generation_server.pb.generate_pb2 import FinishReason
|
from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
|
||||||
from text_generation_server.utils.logits_process import (
|
from text_generation_server.utils.logits_process import (
|
||||||
FrequencyPenaltyLogitsProcessor,
|
FrequencyPenaltyLogitsProcessor,
|
||||||
|
GrammarLogitProcessor,
|
||||||
HeterogeneousProcessorWrapper,
|
HeterogeneousProcessorWrapper,
|
||||||
HeterogeneousRepetitionPenaltyLogitsProcessor,
|
HeterogeneousRepetitionPenaltyLogitsProcessor,
|
||||||
HeterogeneousFrequencyPenaltyLogitsProcessor,
|
HeterogeneousFrequencyPenaltyLogitsProcessor,
|
||||||
|
@ -13,6 +15,7 @@ from text_generation_server.utils.logits_process import (
|
||||||
HeterogeneousTopKLogitsWarper,
|
HeterogeneousTopKLogitsWarper,
|
||||||
HeterogeneousTopPLogitsWarper,
|
HeterogeneousTopPLogitsWarper,
|
||||||
HeterogeneousTypicalLogitsWarper,
|
HeterogeneousTypicalLogitsWarper,
|
||||||
|
HeterogeneousGrammarLogitProcessor,
|
||||||
static_warper,
|
static_warper,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
|
||||||
|
@ -22,16 +25,20 @@ from transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcess
|
||||||
class NextTokenChooser:
|
class NextTokenChooser:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
watermark=False,
|
watermark: bool = False,
|
||||||
temperature=1.0,
|
temperature: float = 1.0,
|
||||||
repetition_penalty=1.0,
|
repetition_penalty: float = 1.0,
|
||||||
frequency_penalty=0.0,
|
frequency_penalty: float = 0.0,
|
||||||
top_k=None,
|
top_k: Optional[int] = None,
|
||||||
top_p=None,
|
top_p: Optional[float] = None,
|
||||||
typical_p=None,
|
typical_p: Optional[float] = None,
|
||||||
do_sample=False,
|
do_sample: bool = False,
|
||||||
seed=0,
|
seed: int = 0,
|
||||||
device="cpu",
|
device: str = "cpu",
|
||||||
|
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||||
|
grammar: str = "",
|
||||||
|
grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
|
||||||
|
fsm_grammar_state: int = 0,
|
||||||
):
|
):
|
||||||
self.watermark_processor = (
|
self.watermark_processor = (
|
||||||
WatermarkLogitsProcessor(device=device) if watermark else None
|
WatermarkLogitsProcessor(device=device) if watermark else None
|
||||||
|
@ -46,6 +53,12 @@ class NextTokenChooser:
|
||||||
if frequency_penalty and frequency_penalty != 0.0
|
if frequency_penalty and frequency_penalty != 0.0
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
self.grammar_processor = (
|
||||||
|
GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
|
||||||
|
if grammar != ""
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
has_warpers = (
|
has_warpers = (
|
||||||
(temperature is not None and temperature != 1.0)
|
(temperature is not None and temperature != 1.0)
|
||||||
|
@ -61,7 +74,10 @@ class NextTokenChooser:
|
||||||
self.static_warper = None
|
self.static_warper = None
|
||||||
|
|
||||||
sampling = do_sample or has_warpers
|
sampling = do_sample or has_warpers
|
||||||
|
|
||||||
self.choice = Sampling(seed, device) if sampling else Greedy()
|
self.choice = Sampling(seed, device) if sampling else Greedy()
|
||||||
|
self.fsm_grammar_state = fsm_grammar_state
|
||||||
|
self.grammar = grammar
|
||||||
|
|
||||||
def __call__(self, input_ids, scores):
|
def __call__(self, input_ids, scores):
|
||||||
if self.watermark_processor is not None:
|
if self.watermark_processor is not None:
|
||||||
|
@ -70,6 +86,8 @@ class NextTokenChooser:
|
||||||
scores = self.repetition_processor(input_ids, scores)
|
scores = self.repetition_processor(input_ids, scores)
|
||||||
if self.frequency_processor is not None:
|
if self.frequency_processor is not None:
|
||||||
scores = self.frequency_processor(input_ids, scores)
|
scores = self.frequency_processor(input_ids, scores)
|
||||||
|
if self.grammar_processor is not None:
|
||||||
|
scores = self.grammar_processor(scores, self.fsm_grammar_state)
|
||||||
|
|
||||||
if self.static_warper is None:
|
if self.static_warper is None:
|
||||||
next_logprob = torch.log_softmax(scores, -1)
|
next_logprob = torch.log_softmax(scores, -1)
|
||||||
|
@ -80,11 +98,19 @@ class NextTokenChooser:
|
||||||
|
|
||||||
return next_id, next_logprob
|
return next_id, next_logprob
|
||||||
|
|
||||||
|
def advance_grammar(self, next_id: int):
|
||||||
|
if self.grammar_processor is not None:
|
||||||
|
self.fsm_grammar_state = self.grammar_processor.advance(
|
||||||
|
next_id, self.fsm_grammar_state
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
cls,
|
cls,
|
||||||
pb: generate_pb2.NextTokenChooserParameters,
|
pb: generate_pb2.NextTokenChooserParameters,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
) -> "NextTokenChooser":
|
) -> "NextTokenChooser":
|
||||||
return NextTokenChooser(
|
return NextTokenChooser(
|
||||||
watermark=pb.watermark,
|
watermark=pb.watermark,
|
||||||
|
@ -97,6 +123,9 @@ class NextTokenChooser:
|
||||||
do_sample=pb.do_sample,
|
do_sample=pb.do_sample,
|
||||||
seed=pb.seed,
|
seed=pb.seed,
|
||||||
device=device,
|
device=device,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
grammar=pb.grammar,
|
||||||
|
grammar_type=pb.grammar_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -201,6 +230,10 @@ class HeterogeneousNextTokenChooser:
|
||||||
typical_p: List[float],
|
typical_p: List[float],
|
||||||
do_sample: List[bool],
|
do_sample: List[bool],
|
||||||
seeds: List[int],
|
seeds: List[int],
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
grammars: List[str],
|
||||||
|
grammar_types: List[int],
|
||||||
|
fsm_grammar_states=List[int],
|
||||||
):
|
):
|
||||||
warpers = []
|
warpers = []
|
||||||
|
|
||||||
|
@ -232,6 +265,14 @@ class HeterogeneousNextTokenChooser:
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.grammar_processor = (
|
||||||
|
HeterogeneousGrammarLogitProcessor(
|
||||||
|
tokenizer, device, grammars, grammar_types
|
||||||
|
)
|
||||||
|
if any([grammar != "" for grammar in grammars])
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
if any([x != 1.0 for x in temperature]):
|
if any([x != 1.0 for x in temperature]):
|
||||||
do_sample = [
|
do_sample = [
|
||||||
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
|
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
|
||||||
|
@ -263,6 +304,10 @@ class HeterogeneousNextTokenChooser:
|
||||||
self.do_sample = do_sample
|
self.do_sample = do_sample
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.fsm_grammar_states = fsm_grammar_states
|
||||||
|
self.grammars = grammars
|
||||||
|
self.grammar_types = grammar_types
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
@ -283,6 +328,8 @@ class HeterogeneousNextTokenChooser:
|
||||||
scores = scores.view(B, S, -1)
|
scores = scores.view(B, S, -1)
|
||||||
|
|
||||||
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
|
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
|
||||||
|
mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
|
||||||
|
|
||||||
for j in range(S):
|
for j in range(S):
|
||||||
_scores = scores[:, j]
|
_scores = scores[:, j]
|
||||||
if self.watermark_processor is not None:
|
if self.watermark_processor is not None:
|
||||||
|
@ -291,10 +338,10 @@ class HeterogeneousNextTokenChooser:
|
||||||
_scores = self.repetition_processor(input_ids, _scores)
|
_scores = self.repetition_processor(input_ids, _scores)
|
||||||
if self.frequency_processor is not None:
|
if self.frequency_processor is not None:
|
||||||
_scores = self.frequency_processor(input_ids, _scores)
|
_scores = self.frequency_processor(input_ids, _scores)
|
||||||
|
|
||||||
for warper in self.warpers:
|
for warper in self.warpers:
|
||||||
_scores = warper(input_ids, _scores)
|
_scores = warper(input_ids, _scores)
|
||||||
|
if self.grammar_processor is not None:
|
||||||
|
_scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask)
|
||||||
_next_ids = self.choice(_scores)
|
_next_ids = self.choice(_scores)
|
||||||
scores[:, j] = _scores
|
scores[:, j] = _scores
|
||||||
next_ids[:, j] = _next_ids
|
next_ids[:, j] = _next_ids
|
||||||
|
@ -352,6 +399,21 @@ class HeterogeneousNextTokenChooser:
|
||||||
|
|
||||||
return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
|
return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids
|
||||||
|
|
||||||
|
def advance_grammar(self, next_ids: List[int]):
|
||||||
|
if self.grammar_processor is not None:
|
||||||
|
other_new_states = self.grammar_processor.advance_batch(
|
||||||
|
next_ids, self.fsm_grammar_states, self.grammars
|
||||||
|
)
|
||||||
|
self.fsm_grammar_states = other_new_states
|
||||||
|
return self
|
||||||
|
|
||||||
|
def advance_grammar_single(self, grammar_state_index: int, next_id: int):
|
||||||
|
if self.grammar_processor is not None:
|
||||||
|
self.fsm_grammar_states[grammar_state_index] = self.grammar_processor.advance_at_index(
|
||||||
|
next_id, self.fsm_grammar_states[grammar_state_index], grammar_state_index
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
def filter(self, indices):
|
def filter(self, indices):
|
||||||
if self.watermark_processor is not None:
|
if self.watermark_processor is not None:
|
||||||
self.watermark_processor = self.watermark_processor.filter(indices)
|
self.watermark_processor = self.watermark_processor.filter(indices)
|
||||||
|
@ -362,6 +424,9 @@ class HeterogeneousNextTokenChooser:
|
||||||
if self.frequency_processor is not None:
|
if self.frequency_processor is not None:
|
||||||
self.frequency_processor = self.frequency_processor.filter(indices)
|
self.frequency_processor = self.frequency_processor.filter(indices)
|
||||||
|
|
||||||
|
if self.grammar_processor is not None:
|
||||||
|
self.grammar_processor = self.grammar_processor.filter(indices)
|
||||||
|
|
||||||
filtered_warpers = []
|
filtered_warpers = []
|
||||||
for warper in self.warpers:
|
for warper in self.warpers:
|
||||||
filtered_warper = warper.filter(indices)
|
filtered_warper = warper.filter(indices)
|
||||||
|
@ -372,6 +437,18 @@ class HeterogeneousNextTokenChooser:
|
||||||
self.seeds = [self.seeds[i] for i in indices]
|
self.seeds = [self.seeds[i] for i in indices]
|
||||||
self.do_sample = [self.do_sample[i] for i in indices]
|
self.do_sample = [self.do_sample[i] for i in indices]
|
||||||
|
|
||||||
|
new_grammars = []
|
||||||
|
new_fsm_grammar_states = []
|
||||||
|
new_grammar_types = []
|
||||||
|
for i in indices:
|
||||||
|
new_grammars.append(self.grammars[i])
|
||||||
|
new_fsm_grammar_states.append(self.fsm_grammar_states[i])
|
||||||
|
new_grammar_types.append(self.grammar_types[i])
|
||||||
|
|
||||||
|
self.grammars = new_grammars
|
||||||
|
self.fsm_grammar_states = new_fsm_grammar_states
|
||||||
|
self.grammar_types = new_grammar_types
|
||||||
|
|
||||||
if any(self.do_sample):
|
if any(self.do_sample):
|
||||||
self.choice.filter(indices)
|
self.choice.filter(indices)
|
||||||
else:
|
else:
|
||||||
|
@ -385,6 +462,7 @@ class HeterogeneousNextTokenChooser:
|
||||||
pb: List[generate_pb2.NextTokenChooserParameters],
|
pb: List[generate_pb2.NextTokenChooserParameters],
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
) -> "HeterogeneousNextTokenChooser":
|
) -> "HeterogeneousNextTokenChooser":
|
||||||
return HeterogeneousNextTokenChooser(
|
return HeterogeneousNextTokenChooser(
|
||||||
watermark=[pb_.watermark for pb_ in pb],
|
watermark=[pb_.watermark for pb_ in pb],
|
||||||
|
@ -398,6 +476,10 @@ class HeterogeneousNextTokenChooser:
|
||||||
seeds=[pb_.seed for pb_ in pb],
|
seeds=[pb_.seed for pb_ in pb],
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
grammars=[pb_.grammar for pb_ in pb],
|
||||||
|
grammar_types=[pb_.grammar_type for pb_ in pb],
|
||||||
|
fsm_grammar_states=[0] * len(pb),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue