diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 5b19eb8c..29ff6d45 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -41,6 +41,10 @@ jobs:
           components: rustfmt, clippy
       - name: Install Protoc
         uses: arduino/setup-protoc@v1
+      - name: Clean unused files
+        run: |
+          sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
+          sudo rm -rf /usr/share/dotnet # will release about 20 GB if you don't need .NET
       - name: Install sccache
         run: |
           curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
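The cleanup step frees roughly 30 GB on a stock GitHub-hosted runner before the heavyweight build steps run. Not part of the change itself, but a stdlib-only sketch of how to confirm how much disk such a step frees, by taking a reading before and after:

    # Sketch only: measure free disk space around a cleanup step.
    import shutil

    usage = shutil.disk_usage("/")
    print(f"free: {usage.free / 1e9:.1f} GB of {usage.total / 1e9:.1f} GB")
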
"getrandom", + "iso8601", + "itoa", + "memchr", + "num-cmp", + "once_cell", + "parking_lot", + "percent-encoding", + "regex", + "reqwest", + "serde", + "serde_json", + "time", + "url", + "uuid", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -1574,12 +1652,84 @@ dependencies = [ "winapi", ] +[[package]] +name = "num" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-cmp" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa" + +[[package]] +name = "num-complex" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.18" @@ -2874,6 +3024,7 @@ dependencies = [ "futures-util", "hf-hub", "init-tracing-opentelemetry", + "jsonschema", "metrics", "metrics-exporter-prometheus", "minijinja", @@ -3530,6 +3681,12 @@ dependencies = [ "zip", ] +[[package]] +name = "uuid" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" + [[package]] name = "valuable" version = "0.1.0" diff --git a/docs/openapi.json b/docs/openapi.json index d72d32e9..fad01aec 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -1022,6 +1022,57 @@ } } }, + "GrammarType": { + "oneOf": [ + { + "type": "object", + "required": [ + "type", + "value" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "json" + ] + }, + "value": { + "type": "string", + "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.", + "example": { + "properties": { + "location": { + "type": "string" + } + } + } + } + } + }, + { + "type": "object", + "required": [ + "type", + "value" + ], + "properties": { + 
"type": { + "type": "string", + "enum": [ + "regex" + ] + }, + "value": { + "type": "string" + } + } + } + ], + "discriminator": { + "propertyName": "type" + } + }, "Info": { "type": "object", "required": [ diff --git a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json index 7b12b158..d7fb620d 100644 --- a/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json +++ b/integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json @@ -135,129 +135,129 @@ "special": false, "text": "\",\"" }, - { - "id": 4230, - "logprob": -0.020492554, - "special": false, - "text": "last" - }, - { - "id": 1170, - "logprob": -0.0013818741, - "special": false, - "text": "Name" - }, - { - "id": 4710, - "logprob": -0.0067749023, - "special": false, - "text": "\":\"" - }, - { - "id": 29950, - "logprob": -0.11578369, - "special": false, - "text": "H" - }, - { - "id": 14339, - "logprob": -0.004131317, - "special": false, - "text": "olt" - }, - { - "id": 29920, - "logprob": -0.0033359528, - "special": false, - "text": "z" - }, - { - "id": 3284, - "logprob": -0.20471191, - "special": false, - "text": "\",\"" - }, { "id": 29882, - "logprob": -0.0069274902, + "logprob": -0.08862305, "special": false, "text": "h" }, { - "id": 20838, - "logprob": -0.19580078, + "id": 711, + "logprob": -0.66259766, "special": false, - "text": "obb" + "text": "ob" }, { - "id": 29891, - "logprob": -2.2649765e-06, + "id": 1609, + "logprob": -5.51939e-05, "special": false, - "text": "y" + "text": "by" }, { "id": 4710, - "logprob": -0.32080078, + "logprob": -0.23120117, "special": false, "text": "\":\"" }, { "id": 29911, - "logprob": -2.1035156, + "logprob": -2.3730469, "special": false, "text": "T" }, { "id": 11003, - "logprob": -0.020767212, + "logprob": -0.032104492, "special": false, "text": "rees" }, { "id": 3284, - "logprob": -0.6010742, + "logprob": -0.22021484, + "special": false, + "text": "\",\"" + }, + { + "id": 4230, + "logprob": -0.06726074, + "special": false, + "text": "last" + }, + { + "id": 1170, + "logprob": -0.003501892, + "special": false, + "text": "Name" + }, + { + "id": 4710, + "logprob": -0.0045661926, + "special": false, + "text": "\":\"" + }, + { + "id": 29950, + "logprob": -0.12512207, + "special": false, + "text": "H" + }, + { + "id": 14339, + "logprob": -0.009552002, + "special": false, + "text": "olt" + }, + { + "id": 29920, + "logprob": -0.00042438507, + "special": false, + "text": "z" + }, + { + "id": 3284, + "logprob": -0.11651611, "special": false, "text": "\",\"" }, { "id": 29876, - "logprob": -0.57666016, + "logprob": -0.29736328, "special": false, "text": "n" }, { "id": 398, - "logprob": -0.0061073303, + "logprob": -0.003030777, "special": false, "text": "um" }, { "id": 29907, - "logprob": -0.45703125, + "logprob": -0.3774414, "special": false, "text": "C" }, { "id": 1446, - "logprob": -0.0002872944, + "logprob": -0.0003130436, "special": false, "text": "ats" }, { "id": 1115, - "logprob": -0.0021018982, + "logprob": -0.0021514893, "special": false, "text": "\":" }, { "id": 29906, - "logprob": -0.08996582, + "logprob": -0.071899414, "special": false, "text": "2" }, { "id": 29913, - "logprob": -0.021697998, + "logprob": -0.018997192, "special": false, "text": "}" }, @@ -270,5 +270,5 @@ ], "top_tokens": null }, - "generated_text": 
"{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}" + "generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}" } diff --git a/integration-tests/models/test_flash_awq.py b/integration-tests/models/test_flash_awq.py index 62a95f48..ead918c3 100644 --- a/integration-tests/models/test_flash_awq.py +++ b/integration-tests/models/test_flash_awq.py @@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq(flash_llama_awq, response_snapshot): response = await flash_llama_awq.generate( "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True @@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): response = await flash_llama_awq.generate( "What is Deep Learning?", @@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot): responses = await generate_load( flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4 diff --git a/integration-tests/models/test_flash_awq_sharded.py b/integration-tests/models/test_flash_awq_sharded.py index 1c687fc9..a83614ac 100644 --- a/integration-tests/models/test_flash_awq_sharded.py +++ b/integration-tests/models/test_flash_awq_sharded.py @@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot): response = await flash_llama_awq_sharded.generate( "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True @@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho @pytest.mark.asyncio -@pytest.mark.private async def test_flash_llama_awq_load_sharded( flash_llama_awq_sharded, generate_load, response_snapshot ): diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py index a0ce0570..e0cc1039 100644 --- a/integration-tests/models/test_flash_medusa.py +++ b/integration-tests/models/test_flash_medusa.py @@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_medusa_simple(flash_medusa, response_snapshot): response = await flash_medusa.generate( "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True @@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_medusa_all_params(flash_medusa, response_snapshot): response = await flash_medusa.generate( "What is Deep Learning?", @@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot): @pytest.mark.asyncio -@pytest.mark.private async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot): responses = await generate_load( flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4 diff --git a/integration-tests/models/test_flash_mistral.py b/integration-tests/models/test_flash_mistral.py index ace3328b..52b51928 100644 --- a/integration-tests/models/test_flash_mistral.py +++ b/integration-tests/models/test_flash_mistral.py @@ -14,7 +14,6 @@ async def 
diff --git a/integration-tests/models/test_flash_awq.py b/integration-tests/models/test_flash_awq.py
index 62a95f48..ead918c3 100644
--- a/integration-tests/models/test_flash_awq.py
+++ b/integration-tests/models/test_flash_awq.py
@@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
         "What is Deep Learning?",
@@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
     responses = await generate_load(
         flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4
diff --git a/integration-tests/models/test_flash_awq_sharded.py b/integration-tests/models/test_flash_awq_sharded.py
index 1c687fc9..a83614ac 100644
--- a/integration-tests/models/test_flash_awq_sharded.py
+++ b/integration-tests/models/test_flash_awq_sharded.py
@@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     response = await flash_llama_awq_sharded.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_load_sharded(
     flash_llama_awq_sharded, generate_load, response_snapshot
 ):
diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py
index a0ce0570..e0cc1039 100644
--- a/integration-tests/models/test_flash_medusa.py
+++ b/integration-tests/models/test_flash_medusa.py
@@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_simple(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
         "What is Deep Learning?",
@@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
     responses = await generate_load(
         flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
diff --git a/integration-tests/models/test_flash_mistral.py b/integration-tests/models/test_flash_mistral.py
index ace3328b..52b51928 100644
--- a/integration-tests/models/test_flash_mistral.py
+++ b/integration-tests/models/test_flash_mistral.py
@@ -14,7 +14,6 @@ async def flash_mistral(flash_mistral_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral(flash_mistral, response_snapshot):
     response = await flash_mistral.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -26,7 +25,6 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
     response = await flash_mistral.generate(
         "Test request",
@@ -49,7 +47,6 @@ async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
     responses = await generate_load(
         flash_mistral, "Test request", max_new_tokens=10, n=4
diff --git a/integration-tests/models/test_flash_phi.py b/integration-tests/models/test_flash_phi.py
index 0987b3a1..9d6ca566 100644
--- a/integration-tests/models/test_flash_phi.py
+++ b/integration-tests/models/test_flash_phi.py
@@ -14,7 +14,6 @@ async def flash_phi(flash_phi_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi(flash_phi, response_snapshot):
     response = await flash_phi.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -26,7 +25,6 @@ async def test_flash_phi(flash_phi, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi_all_params(flash_phi, response_snapshot):
     response = await flash_phi.generate(
         "Test request",
@@ -50,7 +48,6 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
     responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
 
diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py
index 5e448d55..329158b7 100644
--- a/integration-tests/models/test_flash_starcoder_gptq.py
+++ b/integration-tests/models/test_flash_starcoder_gptq.py
@@ -14,7 +14,6 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
@@ -26,7 +25,6 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq_default_params(
     flash_starcoder_gptq, generous_response_snapshot
 ):
@@ -43,7 +41,6 @@ async def test_flash_starcoder_gptq_default_params(
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq_load(
     flash_starcoder_gptq, generate_load, generous_response_snapshot
 ):
diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py
index f068496c..585d0656 100644
--- a/integration-tests/models/test_grammar_llama.py
+++ b/integration-tests/models/test_grammar_llama.py
@@ -19,7 +19,6 @@ async def flash_llama_grammar(flash_llama_grammar_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -30,7 +29,6 @@ async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "Whats Googles DNS",
@@ -49,7 +47,6 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot)
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "info: david holtz like trees and has two cats. ",
@@ -92,13 +89,12 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
     assert response.details.generated_tokens == 30
     assert (
         response.generated_text
-        == '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}'
+        == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
     )
     assert response == response_snapshot
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_load(
     flash_llama_grammar, generate_load, response_snapshot
 ):
@@ -130,7 +126,6 @@ async def test_flash_llama_grammar_load(
 # this is the same as the above test, but only fires off a single request
 # this is only to ensure that the parallel and single inference produce the same result
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_single_load_instance(
     flash_llama_grammar, generate_load, response_snapshot
 ):
diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py
index 5ec2ec31..bf3701b4 100644
--- a/integration-tests/models/test_mamba.py
+++ b/integration-tests/models/test_mamba.py
@@ -14,7 +14,6 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba(fused_kernel_mamba, response_snapshot):
     response = await fused_kernel_mamba.generate(
         "What is Deep Learning?", max_new_tokens=10
@@ -26,7 +25,6 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
     response = await fused_kernel_mamba.generate(
         "blue, red, yellow, ",
@@ -53,7 +51,6 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba_load(
     fused_kernel_mamba, generate_load, generous_response_snapshot
 ):
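Dropping @pytest.mark.private across these suites means the tests now run wherever the integration suite runs, instead of only when gated-model credentials are available. For context, a hedged sketch of how such a marker is commonly gated in a conftest.py; the repo's actual mechanism and token variable name may differ:

    # Hypothetical conftest.py gating: skip tests marked "private" unless a
    # Hub token is present. Names here are assumptions for illustration.
    import os
    import pytest

    def pytest_collection_modifyitems(config, items):
        if os.environ.get("HF_TOKEN"):  # token variable name assumed
            return
        skip = pytest.mark.skip(reason="private model; set HF_TOKEN to run")
        for item in items:
            if "private" in item.keywords:
                item.add_marker(skip)
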
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 7d6dc017..170debda 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -22,6 +22,7 @@ text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 futures = "0.3.28"
 hf-hub = { version = "0.3.0", features = ["tokio"] }
+jsonschema = { version = "0.17.1", features = ["draft202012"] }
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
 nohash-hasher = "0.2.0"
diff --git a/router/src/lib.rs b/router/src/lib.rs
index b7285e65..c6928a5a 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -64,39 +64,16 @@ impl HubTokenizerConfig {
     }
 }
 
-mod json_object_or_string_to_string {
-    use serde::{Deserialize, Deserializer};
-    use serde_json::Value;
-
-    // A custom deserializer that treats both strings and objects as strings.
-    // This provides flexibility with input formats for the 'grammar' field.
-    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
-
-        match value {
-            Value::String(s) => Ok(s),
-            // Safely handle serialization and return an error if it fails
-            Value::Object(o) => {
-                serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
-            }
-            _ => Err(serde::de::Error::custom(
-                "expected string or object for grammar",
-            )),
-        }
-    }
-}
-
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
-    #[serde(
-        rename = "json",
-        deserialize_with = "json_object_or_string_to_string::deserialize"
-    )]
-    Json(String),
+    /// A string that represents a [JSON Schema](https://json-schema.org/).
+    ///
+    /// JSON Schema is a declarative language that allows you to annotate JSON documents
+    /// with types and descriptions.
+    #[serde(rename = "json")]
+    #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
+    Json(serde_json::Value),
     #[serde(rename = "regex")]
     Regex(String),
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 054ba5a2..ebde7133 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -893,6 +893,7 @@ pub async fn run(
             Info,
             CompatGenerateRequest,
             GenerateRequest,
+            GrammarType,
             ChatRequest,
             Message,
             ChatCompletionChoice,
diff --git a/router/src/validation.rs b/router/src/validation.rs
index bf85b12f..204dbf92 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -1,7 +1,9 @@
 /// Payload validation logic
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{GenerateParameters, GenerateRequest, GrammarType};
+use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
+use serde_json::Value;
 use text_generation_client::{
     GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
 };
@@ -313,8 +315,29 @@ impl Validation {
                     return Err(ValidationError::Grammar);
                 }
                 match grammar {
-                    // currently both are handled the same way since compilation is done in Python
-                    GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
+                    GrammarType::Json(json) => {
+                        let json = match json {
+                            // if the value is a string, parse it again to make sure it is
+                            // valid JSON
+                            Value::String(s) => serde_json::from_str(&s)
+                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
+                            Value::Object(_) => Ok(json),
+                            _ => Err(ValidationError::Grammar),
+                        }?;
+
+                        // check that the JSON is a valid JSON Schema
+                        JSONSchema::options()
+                            .with_draft(Draft::Draft202012)
+                            .compile(&json)
+                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
+
+                        (
+                            // serialize the JSON back to a string
+                            serde_json::to_string(&json)
+                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
+                            ProtoGrammarType::Json.into(),
+                        )
+                    }
                     GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
                 }
             }
@@ -486,6 +509,8 @@ pub enum ValidationError {
     Tokenizer(String),
     #[error("grammar is not supported")]
     Grammar,
+    #[error("grammar is not valid: {0}")]
+    InvalidGrammar(String),
 }
 
 #[cfg(test)]
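The new validation path does three things: re-parse a string-valued grammar so both a JSON object and a string containing JSON are accepted, compile the result as a Draft 2020-12 schema so malformed grammars are rejected at the router rather than failing later in the Python shards, and re-serialize the normalized value for the proto message. A Python analogue using the `jsonschema` package (an illustration-only assumption; the router itself uses the Rust `jsonschema` crate wired up above):

    # Python analogue of the router-side grammar checks.
    import json
    from jsonschema import Draft202012Validator
    from jsonschema.exceptions import SchemaError

    def validate_grammar(value):
        # Strings are parsed again so a JSON-in-a-string grammar still works,
        # mirroring the Value::String arm of the Rust match.
        if isinstance(value, str):
            value = json.loads(value)  # raises ValueError on invalid JSON
        if not isinstance(value, dict):
            raise ValueError("expected a JSON Schema object")
        try:
            Draft202012Validator.check_schema(value)
        except SchemaError as e:
            raise ValueError(f"grammar is not valid: {e.message}") from e
        return json.dumps(value)  # serialized form forwarded to the shards

    print(validate_grammar({"properties": {"location": {"type": "string"}}}))
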
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 72c6c21c..32789850 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -328,7 +328,6 @@ class HeterogeneousNextTokenChooser:
             scores = scores.view(B, S, -1)
 
         next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
-        mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
 
         for j in range(S):
             _scores = scores[:, j]
@@ -338,10 +337,10 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.repetition_processor(input_ids, _scores)
             if self.frequency_processor is not None:
                 _scores = self.frequency_processor(input_ids, _scores)
-            for warper in self.warpers:
-                _scores = warper(input_ids, _scores)
             if self.grammar_processor is not None:
                 _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)
             _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
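Moving the warpers after the grammar processor matters because warpers such as top-k and top-p truncate and renormalize the distribution; if they run first, probability mass can be concentrated on tokens the grammar then masks out, leaving sampling with a distorted or even empty candidate set. Masking first means the warpers only ever see grammar-legal tokens. A toy sketch, assuming torch is available:

    # Why grammar masking must precede warping: after masking, all
    # probability mass lands on grammar-legal tokens before any top-p cut.
    import torch

    scores = torch.tensor([[2.0, 1.5, 0.5, -1.0]])
    legal = torch.tensor([[False, True, True, False]])  # grammar-allowed ids

    masked = scores.masked_fill(~legal, float("-inf"))  # grammar first
    probs = torch.softmax(masked, dim=-1)               # then warp/sample
    print(probs)  # nonzero only on the two legal tokens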