fix(router): fix openapi and add jsonschema validation (#1578)

This commit is contained in:
OlivierDehaene 2024-02-21 11:05:32 +01:00 committed by GitHub
parent c9f4c1af31
commit fa8a8e05af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 312 additions and 122 deletions

View File

@ -41,6 +41,10 @@ jobs:
components: rustfmt, clippy
- name: Install Protoc
uses: arduino/setup-protoc@v1
- name: Clean unused files
run: |
sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
- name: Install sccache
run: |
curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache

157
Cargo.lock generated
View File

@ -24,7 +24,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff"
dependencies = [
"cfg-if",
"getrandom",
"once_cell",
"serde",
"version_check",
"zerocopy",
]
@ -265,6 +267,21 @@ version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "bit-set"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
dependencies = [
"bit-vec",
]
[[package]]
name = "bit-vec"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]]
name = "bitflags"
version = "1.3.2"
@ -716,6 +733,16 @@ dependencies = [
"cc",
]
[[package]]
name = "fancy-regex"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2"
dependencies = [
"bit-set",
"regex",
]
[[package]]
name = "fastrand"
version = "2.0.1"
@ -780,6 +807,16 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fraction"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678"
dependencies = [
"lazy_static",
"num",
]
[[package]]
name = "futures"
version = "0.3.30"
@ -895,8 +932,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi",
"wasm-bindgen",
]
[[package]]
@ -1181,6 +1220,15 @@ version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
[[package]]
name = "iso8601"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153"
dependencies = [
"nom",
]
[[package]]
name = "itertools"
version = "0.10.5"
@ -1223,6 +1271,36 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "jsonschema"
version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978"
dependencies = [
"ahash",
"anyhow",
"base64 0.21.7",
"bytecount",
"clap",
"fancy-regex",
"fraction",
"getrandom",
"iso8601",
"itoa",
"memchr",
"num-cmp",
"once_cell",
"parking_lot",
"percent-encoding",
"regex",
"reqwest",
"serde",
"serde_json",
"time",
"url",
"uuid",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
@ -1574,12 +1652,84 @@ dependencies = [
"winapi",
]
[[package]]
name = "num"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
dependencies = [
"num-bigint",
"num-complex",
"num-integer",
"num-iter",
"num-rational",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-cmp"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa"
[[package]]
name = "num-complex"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6"
dependencies = [
"num-traits",
]
[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-iter"
version = "0.1.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
dependencies = [
"autocfg",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.18"
@ -2874,6 +3024,7 @@ dependencies = [
"futures-util",
"hf-hub",
"init-tracing-opentelemetry",
"jsonschema",
"metrics",
"metrics-exporter-prometheus",
"minijinja",
@ -3530,6 +3681,12 @@ dependencies = [
"zip",
]
[[package]]
name = "uuid"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a"
[[package]]
name = "valuable"
version = "0.1.0"

View File

@ -1022,6 +1022,57 @@
}
}
},
"GrammarType": {
"oneOf": [
{
"type": "object",
"required": [
"type",
"value"
],
"properties": {
"type": {
"type": "string",
"enum": [
"json"
]
},
"value": {
"type": "string",
"description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.",
"example": {
"properties": {
"location": {
"type": "string"
}
}
}
}
}
},
{
"type": "object",
"required": [
"type",
"value"
],
"properties": {
"type": {
"type": "string",
"enum": [
"regex"
]
},
"value": {
"type": "string"
}
}
}
],
"discriminator": {
"propertyName": "type"
}
},
"Info": {
"type": "object",
"required": [

View File

@ -135,129 +135,129 @@
"special": false,
"text": "\",\""
},
{
"id": 4230,
"logprob": -0.020492554,
"special": false,
"text": "last"
},
{
"id": 1170,
"logprob": -0.0013818741,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.0067749023,
"special": false,
"text": "\":\""
},
{
"id": 29950,
"logprob": -0.11578369,
"special": false,
"text": "H"
},
{
"id": 14339,
"logprob": -0.004131317,
"special": false,
"text": "olt"
},
{
"id": 29920,
"logprob": -0.0033359528,
"special": false,
"text": "z"
},
{
"id": 3284,
"logprob": -0.20471191,
"special": false,
"text": "\",\""
},
{
"id": 29882,
"logprob": -0.0069274902,
"logprob": -0.08862305,
"special": false,
"text": "h"
},
{
"id": 20838,
"logprob": -0.19580078,
"id": 711,
"logprob": -0.66259766,
"special": false,
"text": "obb"
"text": "ob"
},
{
"id": 29891,
"logprob": -2.2649765e-06,
"id": 1609,
"logprob": -5.51939e-05,
"special": false,
"text": "y"
"text": "by"
},
{
"id": 4710,
"logprob": -0.32080078,
"logprob": -0.23120117,
"special": false,
"text": "\":\""
},
{
"id": 29911,
"logprob": -2.1035156,
"logprob": -2.3730469,
"special": false,
"text": "T"
},
{
"id": 11003,
"logprob": -0.020767212,
"logprob": -0.032104492,
"special": false,
"text": "rees"
},
{
"id": 3284,
"logprob": -0.6010742,
"logprob": -0.22021484,
"special": false,
"text": "\",\""
},
{
"id": 4230,
"logprob": -0.06726074,
"special": false,
"text": "last"
},
{
"id": 1170,
"logprob": -0.003501892,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.0045661926,
"special": false,
"text": "\":\""
},
{
"id": 29950,
"logprob": -0.12512207,
"special": false,
"text": "H"
},
{
"id": 14339,
"logprob": -0.009552002,
"special": false,
"text": "olt"
},
{
"id": 29920,
"logprob": -0.00042438507,
"special": false,
"text": "z"
},
{
"id": 3284,
"logprob": -0.11651611,
"special": false,
"text": "\",\""
},
{
"id": 29876,
"logprob": -0.57666016,
"logprob": -0.29736328,
"special": false,
"text": "n"
},
{
"id": 398,
"logprob": -0.0061073303,
"logprob": -0.003030777,
"special": false,
"text": "um"
},
{
"id": 29907,
"logprob": -0.45703125,
"logprob": -0.3774414,
"special": false,
"text": "C"
},
{
"id": 1446,
"logprob": -0.0002872944,
"logprob": -0.0003130436,
"special": false,
"text": "ats"
},
{
"id": 1115,
"logprob": -0.0021018982,
"logprob": -0.0021514893,
"special": false,
"text": "\":"
},
{
"id": 29906,
"logprob": -0.08996582,
"logprob": -0.071899414,
"special": false,
"text": "2"
},
{
"id": 29913,
"logprob": -0.021697998,
"logprob": -0.018997192,
"special": false,
"text": "}"
},
@ -270,5 +270,5 @@
],
"top_tokens": null
},
"generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}"
"generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}"
}

View File

@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate(
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate(
"What is Deep Learning?",
@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
responses = await generate_load(
flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4

View File

@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
response = await flash_llama_awq_sharded.generate(
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_load_sharded(
flash_llama_awq_sharded, generate_load, response_snapshot
):

View File

@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
response = await flash_medusa.generate(
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
response = await flash_medusa.generate(
"What is Deep Learning?",
@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
responses = await generate_load(
flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4

View File

@ -14,7 +14,6 @@ async def flash_mistral(flash_mistral_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral(flash_mistral, response_snapshot):
response = await flash_mistral.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
@ -26,7 +25,6 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
response = await flash_mistral.generate(
"Test request",
@ -49,7 +47,6 @@ async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
responses = await generate_load(
flash_mistral, "Test request", max_new_tokens=10, n=4

View File

@ -14,7 +14,6 @@ async def flash_phi(flash_phi_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi(flash_phi, response_snapshot):
response = await flash_phi.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
@ -26,7 +25,6 @@ async def test_flash_phi(flash_phi, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_all_params(flash_phi, response_snapshot):
response = await flash_phi.generate(
"Test request",
@ -50,7 +48,6 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)

View File

@ -14,7 +14,6 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
response = await flash_starcoder_gptq.generate(
"def geometric_mean(L: List[float]):",
@ -26,7 +25,6 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq_default_params(
flash_starcoder_gptq, generous_response_snapshot
):
@ -43,7 +41,6 @@ async def test_flash_starcoder_gptq_default_params(
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder_gptq_load(
flash_starcoder_gptq, generate_load, generous_response_snapshot
):

View File

@ -19,7 +19,6 @@ async def flash_llama_grammar(flash_llama_grammar_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
@ -30,7 +29,6 @@ async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Whats Googles DNS",
@ -49,7 +47,6 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot)
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"info: david holtz like trees and has two cats. ",
@ -92,13 +89,12 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
assert response.details.generated_tokens == 30
assert (
response.generated_text
== '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}'
== '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_load(
flash_llama_grammar, generate_load, response_snapshot
):
@ -130,7 +126,6 @@ async def test_flash_llama_grammar_load(
# this is the same as the above test, but only fires off a single request
# this is only to ensure that the parallel and single inference produce the same result
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_single_load_instance(
flash_llama_grammar, generate_load, response_snapshot
):

View File

@ -14,7 +14,6 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
"What is Deep Learning?", max_new_tokens=10
@ -26,7 +25,6 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
response = await fused_kernel_mamba.generate(
"blue, red, yellow, ",
@ -53,7 +51,6 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_mamba_load(
fused_kernel_mamba, generate_load, generous_response_snapshot
):

View File

@ -22,6 +22,7 @@ text-generation-client = { path = "client" }
clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28"
hf-hub = { version = "0.3.0", features = ["tokio"] }
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1"
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
nohash-hasher = "0.2.0"

View File

@ -64,39 +64,16 @@ impl HubTokenizerConfig {
}
}
mod json_object_or_string_to_string {
use serde::{Deserialize, Deserializer};
use serde_json::Value;
// A custom deserializer that treats both strings and objects as strings.
// This provides flexibility with input formats for the 'grammar' field.
pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(s),
// Safely handle serialization and return an error if it fails
Value::Object(o) => {
serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
}
_ => Err(serde::de::Error::custom(
"expected string or object for grammar",
)),
}
}
}
#[derive(Clone, Debug, Deserialize, ToSchema)]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
#[serde(
rename = "json",
deserialize_with = "json_object_or_string_to_string::deserialize"
)]
Json(String),
/// A string that represents a [JSON Schema](https://json-schema.org/).
///
/// JSON Schema is a declarative language that allows to annotate JSON documents
/// with types and descriptions.
#[serde(rename = "json")]
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
Json(serde_json::Value),
#[serde(rename = "regex")]
Regex(String),
}

View File

@ -893,6 +893,7 @@ pub async fn run(
Info,
CompatGenerateRequest,
GenerateRequest,
GrammarType,
ChatRequest,
Message,
ChatCompletionChoice,

View File

@ -1,7 +1,9 @@
/// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest, GrammarType};
use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
@ -313,8 +315,29 @@ impl Validation {
return Err(ValidationError::Grammar);
}
match grammar {
// currently both are handled the same way since compilation is done in Python
GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
GrammarType::Json(json) => {
let json = match json {
// if value is a string, we need to parse it again to make sure its
// a valid json
Value::String(s) => serde_json::from_str(&s)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
Value::Object(_) => Ok(json),
_ => Err(ValidationError::Grammar),
}?;
// Check if the json is a valid JSONSchema
JSONSchema::options()
.with_draft(Draft::Draft202012)
.compile(&json)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
(
// Serialize json to string
serde_json::to_string(&json)
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
ProtoGrammarType::Json.into(),
)
}
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
}
}
@ -486,6 +509,8 @@ pub enum ValidationError {
Tokenizer(String),
#[error("grammar is not supported")]
Grammar,
#[error("grammar is not valid: {0}")]
InvalidGrammar(String),
}
#[cfg(test)]

View File

@ -328,7 +328,6 @@ class HeterogeneousNextTokenChooser:
scores = scores.view(B, S, -1)
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
for j in range(S):
_scores = scores[:, j]
@ -338,10 +337,10 @@ class HeterogeneousNextTokenChooser:
_scores = self.repetition_processor(input_ids, _scores)
if self.frequency_processor is not None:
_scores = self.frequency_processor(input_ids, _scores)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
if self.grammar_processor is not None:
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
_next_ids = self.choice(_scores)
scores[:, j] = _scores
next_ids[:, j] = _next_ids