fix(router): fix openapi and add jsonschema validation (#1578)
This commit is contained in:
parent
c9f4c1af31
commit
fa8a8e05af
|
@ -41,6 +41,10 @@ jobs:
|
||||||
components: rustfmt, clippy
|
components: rustfmt, clippy
|
||||||
- name: Install Protoc
|
- name: Install Protoc
|
||||||
uses: arduino/setup-protoc@v1
|
uses: arduino/setup-protoc@v1
|
||||||
|
- name: Clean unused files
|
||||||
|
run: |
|
||||||
|
sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
|
||||||
|
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
|
||||||
- name: Install sccache
|
- name: Install sccache
|
||||||
run: |
|
run: |
|
||||||
curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
|
curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
|
||||||
|
|
|
@ -24,7 +24,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff"
|
checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
|
"getrandom",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
|
"serde",
|
||||||
"version_check",
|
"version_check",
|
||||||
"zerocopy",
|
"zerocopy",
|
||||||
]
|
]
|
||||||
|
@ -265,6 +267,21 @@ version = "0.21.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bit-set"
|
||||||
|
version = "0.5.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
|
||||||
|
dependencies = [
|
||||||
|
"bit-vec",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bit-vec"
|
||||||
|
version = "0.6.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bitflags"
|
name = "bitflags"
|
||||||
version = "1.3.2"
|
version = "1.3.2"
|
||||||
|
@ -716,6 +733,16 @@ dependencies = [
|
||||||
"cc",
|
"cc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fancy-regex"
|
||||||
|
version = "0.11.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2"
|
||||||
|
dependencies = [
|
||||||
|
"bit-set",
|
||||||
|
"regex",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastrand"
|
name = "fastrand"
|
||||||
version = "2.0.1"
|
version = "2.0.1"
|
||||||
|
@ -780,6 +807,16 @@ dependencies = [
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fraction"
|
||||||
|
version = "0.13.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678"
|
||||||
|
dependencies = [
|
||||||
|
"lazy_static",
|
||||||
|
"num",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures"
|
name = "futures"
|
||||||
version = "0.3.30"
|
version = "0.3.30"
|
||||||
|
@ -895,8 +932,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
|
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
|
"js-sys",
|
||||||
"libc",
|
"libc",
|
||||||
"wasi",
|
"wasi",
|
||||||
|
"wasm-bindgen",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1181,6 +1220,15 @@ version = "2.9.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
|
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iso8601"
|
||||||
|
version = "0.6.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153"
|
||||||
|
dependencies = [
|
||||||
|
"nom",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "itertools"
|
name = "itertools"
|
||||||
version = "0.10.5"
|
version = "0.10.5"
|
||||||
|
@ -1223,6 +1271,36 @@ dependencies = [
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jsonschema"
|
||||||
|
version = "0.17.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978"
|
||||||
|
dependencies = [
|
||||||
|
"ahash",
|
||||||
|
"anyhow",
|
||||||
|
"base64 0.21.7",
|
||||||
|
"bytecount",
|
||||||
|
"clap",
|
||||||
|
"fancy-regex",
|
||||||
|
"fraction",
|
||||||
|
"getrandom",
|
||||||
|
"iso8601",
|
||||||
|
"itoa",
|
||||||
|
"memchr",
|
||||||
|
"num-cmp",
|
||||||
|
"once_cell",
|
||||||
|
"parking_lot",
|
||||||
|
"percent-encoding",
|
||||||
|
"regex",
|
||||||
|
"reqwest",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"time",
|
||||||
|
"url",
|
||||||
|
"uuid",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lazy_static"
|
name = "lazy_static"
|
||||||
version = "1.4.0"
|
version = "1.4.0"
|
||||||
|
@ -1574,12 +1652,84 @@ dependencies = [
|
||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num"
|
||||||
|
version = "0.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
|
||||||
|
dependencies = [
|
||||||
|
"num-bigint",
|
||||||
|
"num-complex",
|
||||||
|
"num-integer",
|
||||||
|
"num-iter",
|
||||||
|
"num-rational",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-bigint"
|
||||||
|
version = "0.4.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-cmp"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-complex"
|
||||||
|
version = "0.4.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num-conv"
|
name = "num-conv"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
|
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-integer"
|
||||||
|
version = "0.1.46"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-iter"
|
||||||
|
version = "0.1.44"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-rational"
|
||||||
|
version = "0.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"num-bigint",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num-traits"
|
name = "num-traits"
|
||||||
version = "0.2.18"
|
version = "0.2.18"
|
||||||
|
@ -2874,6 +3024,7 @@ dependencies = [
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"hf-hub",
|
"hf-hub",
|
||||||
"init-tracing-opentelemetry",
|
"init-tracing-opentelemetry",
|
||||||
|
"jsonschema",
|
||||||
"metrics",
|
"metrics",
|
||||||
"metrics-exporter-prometheus",
|
"metrics-exporter-prometheus",
|
||||||
"minijinja",
|
"minijinja",
|
||||||
|
@ -3530,6 +3681,12 @@ dependencies = [
|
||||||
"zip",
|
"zip",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "uuid"
|
||||||
|
version = "1.7.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "valuable"
|
name = "valuable"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
|
@ -1022,6 +1022,57 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"GrammarType": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"type",
|
||||||
|
"value"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"json"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"value": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.",
|
||||||
|
"example": {
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"type",
|
||||||
|
"value"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"regex"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"value": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"discriminator": {
|
||||||
|
"propertyName": "type"
|
||||||
|
}
|
||||||
|
},
|
||||||
"Info": {
|
"Info": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
|
|
|
@ -135,129 +135,129 @@
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\",\""
|
"text": "\",\""
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": 4230,
|
|
||||||
"logprob": -0.020492554,
|
|
||||||
"special": false,
|
|
||||||
"text": "last"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 1170,
|
|
||||||
"logprob": -0.0013818741,
|
|
||||||
"special": false,
|
|
||||||
"text": "Name"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 4710,
|
|
||||||
"logprob": -0.0067749023,
|
|
||||||
"special": false,
|
|
||||||
"text": "\":\""
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29950,
|
|
||||||
"logprob": -0.11578369,
|
|
||||||
"special": false,
|
|
||||||
"text": "H"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 14339,
|
|
||||||
"logprob": -0.004131317,
|
|
||||||
"special": false,
|
|
||||||
"text": "olt"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 29920,
|
|
||||||
"logprob": -0.0033359528,
|
|
||||||
"special": false,
|
|
||||||
"text": "z"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 3284,
|
|
||||||
"logprob": -0.20471191,
|
|
||||||
"special": false,
|
|
||||||
"text": "\",\""
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 29882,
|
"id": 29882,
|
||||||
"logprob": -0.0069274902,
|
"logprob": -0.08862305,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "h"
|
"text": "h"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 20838,
|
"id": 711,
|
||||||
"logprob": -0.19580078,
|
"logprob": -0.66259766,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "obb"
|
"text": "ob"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29891,
|
"id": 1609,
|
||||||
"logprob": -2.2649765e-06,
|
"logprob": -5.51939e-05,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "y"
|
"text": "by"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4710,
|
"id": 4710,
|
||||||
"logprob": -0.32080078,
|
"logprob": -0.23120117,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\":\""
|
"text": "\":\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29911,
|
"id": 29911,
|
||||||
"logprob": -2.1035156,
|
"logprob": -2.3730469,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "T"
|
"text": "T"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 11003,
|
"id": 11003,
|
||||||
"logprob": -0.020767212,
|
"logprob": -0.032104492,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "rees"
|
"text": "rees"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3284,
|
"id": 3284,
|
||||||
"logprob": -0.6010742,
|
"logprob": -0.22021484,
|
||||||
|
"special": false,
|
||||||
|
"text": "\",\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4230,
|
||||||
|
"logprob": -0.06726074,
|
||||||
|
"special": false,
|
||||||
|
"text": "last"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1170,
|
||||||
|
"logprob": -0.003501892,
|
||||||
|
"special": false,
|
||||||
|
"text": "Name"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 4710,
|
||||||
|
"logprob": -0.0045661926,
|
||||||
|
"special": false,
|
||||||
|
"text": "\":\""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29950,
|
||||||
|
"logprob": -0.12512207,
|
||||||
|
"special": false,
|
||||||
|
"text": "H"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 14339,
|
||||||
|
"logprob": -0.009552002,
|
||||||
|
"special": false,
|
||||||
|
"text": "olt"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 29920,
|
||||||
|
"logprob": -0.00042438507,
|
||||||
|
"special": false,
|
||||||
|
"text": "z"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3284,
|
||||||
|
"logprob": -0.11651611,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\",\""
|
"text": "\",\""
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29876,
|
"id": 29876,
|
||||||
"logprob": -0.57666016,
|
"logprob": -0.29736328,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "n"
|
"text": "n"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 398,
|
"id": 398,
|
||||||
"logprob": -0.0061073303,
|
"logprob": -0.003030777,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "um"
|
"text": "um"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29907,
|
"id": 29907,
|
||||||
"logprob": -0.45703125,
|
"logprob": -0.3774414,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "C"
|
"text": "C"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1446,
|
"id": 1446,
|
||||||
"logprob": -0.0002872944,
|
"logprob": -0.0003130436,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "ats"
|
"text": "ats"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1115,
|
"id": 1115,
|
||||||
"logprob": -0.0021018982,
|
"logprob": -0.0021514893,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\":"
|
"text": "\":"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29906,
|
"id": 29906,
|
||||||
"logprob": -0.08996582,
|
"logprob": -0.071899414,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "2"
|
"text": "2"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 29913,
|
"id": 29913,
|
||||||
"logprob": -0.021697998,
|
"logprob": -0.018997192,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "}"
|
"text": "}"
|
||||||
},
|
},
|
||||||
|
@ -270,5 +270,5 @@
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}"
|
"generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}"
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
|
async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
|
||||||
response = await flash_llama_awq.generate(
|
response = await flash_llama_awq.generate(
|
||||||
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
|
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
|
||||||
|
@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
|
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
|
||||||
response = await flash_llama_awq.generate(
|
response = await flash_llama_awq.generate(
|
||||||
"What is Deep Learning?",
|
"What is Deep Learning?",
|
||||||
|
@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
|
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4
|
flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4
|
||||||
|
|
|
@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
|
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
|
||||||
response = await flash_llama_awq_sharded.generate(
|
response = await flash_llama_awq_sharded.generate(
|
||||||
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
|
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
|
||||||
|
@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_awq_load_sharded(
|
async def test_flash_llama_awq_load_sharded(
|
||||||
flash_llama_awq_sharded, generate_load, response_snapshot
|
flash_llama_awq_sharded, generate_load, response_snapshot
|
||||||
):
|
):
|
||||||
|
|
|
@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
|
async def test_flash_medusa_simple(flash_medusa, response_snapshot):
|
||||||
response = await flash_medusa.generate(
|
response = await flash_medusa.generate(
|
||||||
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
|
"What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
|
||||||
|
@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
|
async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
|
||||||
response = await flash_medusa.generate(
|
response = await flash_medusa.generate(
|
||||||
"What is Deep Learning?",
|
"What is Deep Learning?",
|
||||||
|
@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
|
async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
|
flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
|
||||||
|
|
|
@ -14,7 +14,6 @@ async def flash_mistral(flash_mistral_handle):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_mistral(flash_mistral, response_snapshot):
|
async def test_flash_mistral(flash_mistral, response_snapshot):
|
||||||
response = await flash_mistral.generate(
|
response = await flash_mistral.generate(
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
|
@ -26,7 +25,6 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
|
async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
|
||||||
response = await flash_mistral.generate(
|
response = await flash_mistral.generate(
|
||||||
"Test request",
|
"Test request",
|
||||||
|
@ -49,7 +47,6 @@ async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
|
async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
|
||||||
responses = await generate_load(
|
responses = await generate_load(
|
||||||
flash_mistral, "Test request", max_new_tokens=10, n=4
|
flash_mistral, "Test request", max_new_tokens=10, n=4
|
||||||
|
|
|
@ -14,7 +14,6 @@ async def flash_phi(flash_phi_handle):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_phi(flash_phi, response_snapshot):
|
async def test_flash_phi(flash_phi, response_snapshot):
|
||||||
response = await flash_phi.generate(
|
response = await flash_phi.generate(
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
|
@ -26,7 +25,6 @@ async def test_flash_phi(flash_phi, response_snapshot):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_phi_all_params(flash_phi, response_snapshot):
|
async def test_flash_phi_all_params(flash_phi, response_snapshot):
|
||||||
response = await flash_phi.generate(
|
response = await flash_phi.generate(
|
||||||
"Test request",
|
"Test request",
|
||||||
|
@ -50,7 +48,6 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
|
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
|
||||||
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
|
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,6 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
|
async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
|
||||||
response = await flash_starcoder_gptq.generate(
|
response = await flash_starcoder_gptq.generate(
|
||||||
"def geometric_mean(L: List[float]):",
|
"def geometric_mean(L: List[float]):",
|
||||||
|
@ -26,7 +25,6 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_starcoder_gptq_default_params(
|
async def test_flash_starcoder_gptq_default_params(
|
||||||
flash_starcoder_gptq, generous_response_snapshot
|
flash_starcoder_gptq, generous_response_snapshot
|
||||||
):
|
):
|
||||||
|
@ -43,7 +41,6 @@ async def test_flash_starcoder_gptq_default_params(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_starcoder_gptq_load(
|
async def test_flash_starcoder_gptq_load(
|
||||||
flash_starcoder_gptq, generate_load, generous_response_snapshot
|
flash_starcoder_gptq, generate_load, generous_response_snapshot
|
||||||
):
|
):
|
||||||
|
|
|
@ -19,7 +19,6 @@ async def flash_llama_grammar(flash_llama_grammar_handle):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
|
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
|
||||||
response = await flash_llama_grammar.generate(
|
response = await flash_llama_grammar.generate(
|
||||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||||
|
@ -30,7 +29,6 @@ async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
|
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
|
||||||
response = await flash_llama_grammar.generate(
|
response = await flash_llama_grammar.generate(
|
||||||
"Whats Googles DNS",
|
"Whats Googles DNS",
|
||||||
|
@ -49,7 +47,6 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
||||||
response = await flash_llama_grammar.generate(
|
response = await flash_llama_grammar.generate(
|
||||||
"info: david holtz like trees and has two cats. ",
|
"info: david holtz like trees and has two cats. ",
|
||||||
|
@ -92,13 +89,12 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
|
||||||
assert response.details.generated_tokens == 30
|
assert response.details.generated_tokens == 30
|
||||||
assert (
|
assert (
|
||||||
response.generated_text
|
response.generated_text
|
||||||
== '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}'
|
== '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
|
||||||
)
|
)
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_grammar_load(
|
async def test_flash_llama_grammar_load(
|
||||||
flash_llama_grammar, generate_load, response_snapshot
|
flash_llama_grammar, generate_load, response_snapshot
|
||||||
):
|
):
|
||||||
|
@ -130,7 +126,6 @@ async def test_flash_llama_grammar_load(
|
||||||
# this is the same as the above test, but only fires off a single request
|
# this is the same as the above test, but only fires off a single request
|
||||||
# this is only to ensure that the parallel and single inference produce the same result
|
# this is only to ensure that the parallel and single inference produce the same result
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_flash_llama_grammar_single_load_instance(
|
async def test_flash_llama_grammar_single_load_instance(
|
||||||
flash_llama_grammar, generate_load, response_snapshot
|
flash_llama_grammar, generate_load, response_snapshot
|
||||||
):
|
):
|
||||||
|
|
|
@ -14,7 +14,6 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_mamba(fused_kernel_mamba, response_snapshot):
|
async def test_mamba(fused_kernel_mamba, response_snapshot):
|
||||||
response = await fused_kernel_mamba.generate(
|
response = await fused_kernel_mamba.generate(
|
||||||
"What is Deep Learning?", max_new_tokens=10
|
"What is Deep Learning?", max_new_tokens=10
|
||||||
|
@ -26,7 +25,6 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
||||||
response = await fused_kernel_mamba.generate(
|
response = await fused_kernel_mamba.generate(
|
||||||
"blue, red, yellow, ",
|
"blue, red, yellow, ",
|
||||||
|
@ -53,7 +51,6 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.private
|
|
||||||
async def test_mamba_load(
|
async def test_mamba_load(
|
||||||
fused_kernel_mamba, generate_load, generous_response_snapshot
|
fused_kernel_mamba, generate_load, generous_response_snapshot
|
||||||
):
|
):
|
||||||
|
|
|
@ -22,6 +22,7 @@ text-generation-client = { path = "client" }
|
||||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||||
futures = "0.3.28"
|
futures = "0.3.28"
|
||||||
hf-hub = { version = "0.3.0", features = ["tokio"] }
|
hf-hub = { version = "0.3.0", features = ["tokio"] }
|
||||||
|
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||||
metrics = "0.21.1"
|
metrics = "0.21.1"
|
||||||
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
|
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
|
||||||
nohash-hasher = "0.2.0"
|
nohash-hasher = "0.2.0"
|
||||||
|
|
|
@ -64,39 +64,16 @@ impl HubTokenizerConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mod json_object_or_string_to_string {
|
|
||||||
use serde::{Deserialize, Deserializer};
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
// A custom deserializer that treats both strings and objects as strings.
|
|
||||||
// This provides flexibility with input formats for the 'grammar' field.
|
|
||||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
|
|
||||||
where
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
{
|
|
||||||
let value = Value::deserialize(deserializer)?;
|
|
||||||
|
|
||||||
match value {
|
|
||||||
Value::String(s) => Ok(s),
|
|
||||||
// Safely handle serialization and return an error if it fails
|
|
||||||
Value::Object(o) => {
|
|
||||||
serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
|
|
||||||
}
|
|
||||||
_ => Err(serde::de::Error::custom(
|
|
||||||
"expected string or object for grammar",
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, ToSchema)]
|
||||||
#[serde(tag = "type", content = "value")]
|
#[serde(tag = "type", content = "value")]
|
||||||
pub(crate) enum GrammarType {
|
pub(crate) enum GrammarType {
|
||||||
#[serde(
|
/// A string that represents a [JSON Schema](https://json-schema.org/).
|
||||||
rename = "json",
|
///
|
||||||
deserialize_with = "json_object_or_string_to_string::deserialize"
|
/// JSON Schema is a declarative language that allows to annotate JSON documents
|
||||||
)]
|
/// with types and descriptions.
|
||||||
Json(String),
|
#[serde(rename = "json")]
|
||||||
|
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
|
||||||
|
Json(serde_json::Value),
|
||||||
#[serde(rename = "regex")]
|
#[serde(rename = "regex")]
|
||||||
Regex(String),
|
Regex(String),
|
||||||
}
|
}
|
||||||
|
|
|
@ -893,6 +893,7 @@ pub async fn run(
|
||||||
Info,
|
Info,
|
||||||
CompatGenerateRequest,
|
CompatGenerateRequest,
|
||||||
GenerateRequest,
|
GenerateRequest,
|
||||||
|
GrammarType,
|
||||||
ChatRequest,
|
ChatRequest,
|
||||||
Message,
|
Message,
|
||||||
ChatCompletionChoice,
|
ChatCompletionChoice,
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
/// Payload validation logic
|
/// Payload validation logic
|
||||||
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
|
||||||
use crate::{GenerateParameters, GenerateRequest, GrammarType};
|
use crate::{GenerateParameters, GenerateRequest, GrammarType};
|
||||||
|
use jsonschema::{Draft, JSONSchema};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
|
use serde_json::Value;
|
||||||
use text_generation_client::{
|
use text_generation_client::{
|
||||||
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
|
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
|
||||||
};
|
};
|
||||||
|
@ -313,8 +315,29 @@ impl Validation {
|
||||||
return Err(ValidationError::Grammar);
|
return Err(ValidationError::Grammar);
|
||||||
}
|
}
|
||||||
match grammar {
|
match grammar {
|
||||||
// currently both are handled the same way since compilation is done in Python
|
GrammarType::Json(json) => {
|
||||||
GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
|
let json = match json {
|
||||||
|
// if value is a string, we need to parse it again to make sure its
|
||||||
|
// a valid json
|
||||||
|
Value::String(s) => serde_json::from_str(&s)
|
||||||
|
.map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
|
||||||
|
Value::Object(_) => Ok(json),
|
||||||
|
_ => Err(ValidationError::Grammar),
|
||||||
|
}?;
|
||||||
|
|
||||||
|
// Check if the json is a valid JSONSchema
|
||||||
|
JSONSchema::options()
|
||||||
|
.with_draft(Draft::Draft202012)
|
||||||
|
.compile(&json)
|
||||||
|
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
|
||||||
|
|
||||||
|
(
|
||||||
|
// Serialize json to string
|
||||||
|
serde_json::to_string(&json)
|
||||||
|
.map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
|
||||||
|
ProtoGrammarType::Json.into(),
|
||||||
|
)
|
||||||
|
}
|
||||||
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
|
GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -486,6 +509,8 @@ pub enum ValidationError {
|
||||||
Tokenizer(String),
|
Tokenizer(String),
|
||||||
#[error("grammar is not supported")]
|
#[error("grammar is not supported")]
|
||||||
Grammar,
|
Grammar,
|
||||||
|
#[error("grammar is not valid: {0}")]
|
||||||
|
InvalidGrammar(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
@ -328,7 +328,6 @@ class HeterogeneousNextTokenChooser:
|
||||||
scores = scores.view(B, S, -1)
|
scores = scores.view(B, S, -1)
|
||||||
|
|
||||||
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
|
next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
|
||||||
mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
|
|
||||||
|
|
||||||
for j in range(S):
|
for j in range(S):
|
||||||
_scores = scores[:, j]
|
_scores = scores[:, j]
|
||||||
|
@ -338,10 +337,10 @@ class HeterogeneousNextTokenChooser:
|
||||||
_scores = self.repetition_processor(input_ids, _scores)
|
_scores = self.repetition_processor(input_ids, _scores)
|
||||||
if self.frequency_processor is not None:
|
if self.frequency_processor is not None:
|
||||||
_scores = self.frequency_processor(input_ids, _scores)
|
_scores = self.frequency_processor(input_ids, _scores)
|
||||||
for warper in self.warpers:
|
|
||||||
_scores = warper(input_ids, _scores)
|
|
||||||
if self.grammar_processor is not None:
|
if self.grammar_processor is not None:
|
||||||
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
|
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
|
||||||
|
for warper in self.warpers:
|
||||||
|
_scores = warper(input_ids, _scores)
|
||||||
_next_ids = self.choice(_scores)
|
_next_ids = self.choice(_scores)
|
||||||
scores[:, j] = _scores
|
scores[:, j] = _scores
|
||||||
next_ids[:, j] = _next_ids
|
next_ids[:, j] = _next_ids
|
||||||
|
|
Loading…
Reference in New Issue