From 4580ced091007ee110636ac559b78bc7c2b3b017 Mon Sep 17 00:00:00 2001 From: Alex Weston <43505988+aW3st@users.noreply.github.com> Date: Fri, 15 Nov 2024 07:22:52 -0500 Subject: [PATCH] Upgrade outlines to 0.1.1 (#2742) * Upgrade outlines to 0.1.1 * Update for new API * Check if allowed tokens is None --------- Co-authored-by: Nicolas Patry --- server/poetry.lock | 170 +++++++++--------- server/pyproject.toml | 2 +- .../utils/logits_process.py | 18 +- 3 files changed, 93 insertions(+), 97 deletions(-) diff --git a/server/poetry.lock b/server/poetry.lock index d5b84de3..ad7dab18 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "accelerate" @@ -167,6 +167,17 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "airportsdata" +version = "20241001" +description = "Extensive database of location and timezone data for nearly every airport and landing strip in the world." +optional = true +python-versions = ">=3.9" +files = [ + {file = "airportsdata-20241001-py3-none-any.whl", hash = "sha256:67d71cf2c5378cc17ff66b62b1e11aa2444043949c894543ac8fd8dafce192fd"}, + {file = "airportsdata-20241001.tar.gz", hash = "sha256:fa0bd143b4f4be3557cb892fa0612ef210fd91a92bd720b4d8221de576a4fa00"}, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -1043,17 +1054,6 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] -[[package]] -name = "joblib" -version = "1.4.2" -description = "Lightweight pipelining with Python functions" -optional = true -python-versions = ">=3.8" -files = [ - {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, - {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, -] - [[package]] name = "jsonschema" version = "4.23.0" @@ -1106,36 +1106,6 @@ interegular = ["interegular (>=0.3.1,<0.4.0)"] nearley = ["js2py"] regex = ["regex"] -[[package]] -name = "llvmlite" -version = "0.43.0" -description = "lightweight wrapper around basic LLVM functionality" -optional = true -python-versions = ">=3.9" -files = [ - {file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"}, - {file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"}, - {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d434ec7e2ce3cc8f452d1cd9a28591745de022f931d67be688a737320dfcead"}, - {file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6912a87782acdff6eb8bf01675ed01d60ca1f2551f8176a300a886f09e836a6a"}, - {file = "llvmlite-0.43.0-cp310-cp310-win_amd64.whl", hash = "sha256:14f0e4bf2fd2d9a75a3534111e8ebeb08eda2f33e9bdd6dfa13282afacdde0ed"}, - {file = "llvmlite-0.43.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8d0618cb9bfe40ac38a9633f2493d4d4e9fcc2f438d39a4e854f39cc0f5f98"}, - {file = "llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0a9a1a39d4bf3517f2af9d23d479b4175ead205c592ceeb8b89af48a327ea57"}, - {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1da416ab53e4f7f3bc8d4eeba36d801cc1894b9fbfbf2022b29b6bad34a7df2"}, - {file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977525a1e5f4059316b183fb4fd34fa858c9eade31f165427a3977c95e3ee749"}, - {file = "llvmlite-0.43.0-cp311-cp311-win_amd64.whl", hash = "sha256:d5bd550001d26450bd90777736c69d68c487d17bf371438f975229b2b8241a91"}, - {file = "llvmlite-0.43.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f99b600aa7f65235a5a05d0b9a9f31150c390f31261f2a0ba678e26823ec38f7"}, - {file = "llvmlite-0.43.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:35d80d61d0cda2d767f72de99450766250560399edc309da16937b93d3b676e7"}, - {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eccce86bba940bae0d8d48ed925f21dbb813519169246e2ab292b5092aba121f"}, - {file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df6509e1507ca0760787a199d19439cc887bfd82226f5af746d6977bd9f66844"}, - {file = "llvmlite-0.43.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a2872ee80dcf6b5dbdc838763d26554c2a18aa833d31a2635bff16aafefb9c9"}, - {file = "llvmlite-0.43.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9cd2a7376f7b3367019b664c21f0c61766219faa3b03731113ead75107f3b66c"}, - {file = "llvmlite-0.43.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18e9953c748b105668487b7c81a3e97b046d8abf95c4ddc0cd3c94f4e4651ae8"}, - {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74937acd22dc11b33946b67dca7680e6d103d6e90eeaaaf932603bec6fe7b03a"}, - {file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9efc739cc6ed760f795806f67889923f7274276f0eb45092a1473e40d9b867"}, - {file = "llvmlite-0.43.0-cp39-cp39-win_amd64.whl", hash = "sha256:47e147cdda9037f94b399bf03bfd8a6b6b1f2f90be94a454e3386f006455a9b4"}, - {file = "llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5"}, -] - [[package]] name = "loguru" version = "0.6.0" @@ -1577,40 +1547,6 @@ doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9. extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"] test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] -[[package]] -name = "numba" -version = "0.60.0" -description = "compiling Python code using LLVM" -optional = true -python-versions = ">=3.9" -files = [ - {file = "numba-0.60.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d761de835cd38fb400d2c26bb103a2726f548dc30368853121d66201672e651"}, - {file = "numba-0.60.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:159e618ef213fba758837f9837fb402bbe65326e60ba0633dbe6c7f274d42c1b"}, - {file = "numba-0.60.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1527dc578b95c7c4ff248792ec33d097ba6bef9eda466c948b68dfc995c25781"}, - {file = "numba-0.60.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe0b28abb8d70f8160798f4de9d486143200f34458d34c4a214114e445d7124e"}, - {file = "numba-0.60.0-cp310-cp310-win_amd64.whl", hash = "sha256:19407ced081d7e2e4b8d8c36aa57b7452e0283871c296e12d798852bc7d7f198"}, - {file = "numba-0.60.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a17b70fc9e380ee29c42717e8cc0bfaa5556c416d94f9aa96ba13acb41bdece8"}, - {file = "numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb02b344a2a80efa6f677aa5c40cd5dd452e1b35f8d1c2af0dfd9ada9978e4b"}, - {file = "numba-0.60.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f4fde652ea604ea3c86508a3fb31556a6157b2c76c8b51b1d45eb40c8598703"}, - {file = "numba-0.60.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4142d7ac0210cc86432b818338a2bc368dc773a2f5cf1e32ff7c5b378bd63ee8"}, - {file = "numba-0.60.0-cp311-cp311-win_amd64.whl", hash = "sha256:cac02c041e9b5bc8cf8f2034ff6f0dbafccd1ae9590dc146b3a02a45e53af4e2"}, - {file = "numba-0.60.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d7da4098db31182fc5ffe4bc42c6f24cd7d1cb8a14b59fd755bfee32e34b8404"}, - {file = "numba-0.60.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:38d6ea4c1f56417076ecf8fc327c831ae793282e0ff51080c5094cb726507b1c"}, - {file = "numba-0.60.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:62908d29fb6a3229c242e981ca27e32a6e606cc253fc9e8faeb0e48760de241e"}, - {file = "numba-0.60.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0ebaa91538e996f708f1ab30ef4d3ddc344b64b5227b67a57aa74f401bb68b9d"}, - {file = "numba-0.60.0-cp312-cp312-win_amd64.whl", hash = "sha256:f75262e8fe7fa96db1dca93d53a194a38c46da28b112b8a4aca168f0df860347"}, - {file = "numba-0.60.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:01ef4cd7d83abe087d644eaa3d95831b777aa21d441a23703d649e06b8e06b74"}, - {file = "numba-0.60.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:819a3dfd4630d95fd574036f99e47212a1af41cbcb019bf8afac63ff56834449"}, - {file = "numba-0.60.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b983bd6ad82fe868493012487f34eae8bf7dd94654951404114f23c3466d34b"}, - {file = "numba-0.60.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c151748cd269ddeab66334bd754817ffc0cabd9433acb0f551697e5151917d25"}, - {file = "numba-0.60.0-cp39-cp39-win_amd64.whl", hash = "sha256:3031547a015710140e8c87226b4cfe927cac199835e5bf7d4fe5cb64e814e3ab"}, - {file = "numba-0.60.0.tar.gz", hash = "sha256:5df6158e5584eece5fc83294b949fd30b9f1125df7708862205217e068aabf16"}, -] - -[package.dependencies] -llvmlite = "==0.43.*" -numpy = ">=1.22,<2.1" - [[package]] name = "numpy" version = "1.26.4" @@ -1988,36 +1924,83 @@ opentelemetry-api = "1.25.0" [[package]] name = "outlines" -version = "0.0.34" +version = "0.1.1" description = "Probabilistic Generative Model Programming" optional = true -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "outlines-0.0.34-py3-none-any.whl", hash = "sha256:911588a7e64a4f193b97fb4c501d98ccfd4e95a98f6a3ada67a280bf0c373c50"}, - {file = "outlines-0.0.34.tar.gz", hash = "sha256:594e7204c770b47a62eb5c2ba7d25ea0ab2e16882b5f04556712a0228d3d3309"}, + {file = "outlines-0.1.1-py3-none-any.whl", hash = "sha256:896aee7f8f0472955104bb30fb118e525bced6885f09e833bb848782394f2c17"}, + {file = "outlines-0.1.1.tar.gz", hash = "sha256:9c5d3524ef21343bd681757e8ed9a5b1fcb335ee68f9b6b0889062ce23b561fc"}, ] [package.dependencies] +airportsdata = "*" cloudpickle = "*" +datasets = "*" diskcache = "*" interegular = "*" jinja2 = "*" -joblib = "*" jsonschema = "*" lark = "*" nest-asyncio = "*" -numba = "*" -numpy = "*" +numpy = "<2.0.0" +outlines-core = "0.1.14" +pycountry = "*" pydantic = ">=2.0" referencing = "*" requests = "*" -scipy = "*" -torch = ">=2.1.0" -transformers = "*" +torch = "*" +tqdm = "*" +typing-extensions = "*" [package.extras] -serve = ["fastapi", "pydantic (>=2.0)", "ray (==2.9.0)", "uvicorn", "vllm (>=0.3.0)"] -test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python (>=0.2.42)", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"] +serve = ["fastapi", "pydantic (>=2.0)", "uvicorn", "vllm (>=0.3.0)"] +test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "diff-cover", "exllamav2", "huggingface-hub", "llama-cpp-python", "mlx-lm", "openai (>=1.0.0)", "pillow", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers", "vllm"] + +[[package]] +name = "outlines-core" +version = "0.1.14" +description = "Structured Text Generation in Rust" +optional = true +python-versions = ">=3.8" +files = [ + {file = "outlines_core-0.1.14-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:291c6d9d348cb5562cd28ce44d80822d77238f1cd7c30d890b5b20488e71608d"}, + {file = "outlines_core-0.1.14-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3a50e2f6745e0c34cc857d1bd5590e2966ad06e8ce10802976e9e6c116c7533d"}, + {file = "outlines_core-0.1.14-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7dfe64b590a6a88dcc5e59f0a399fff0458cdcf97d68de07f08e1bd3bf8ac1d"}, + {file = "outlines_core-0.1.14-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:100de068ce52893bec316481e65db8f1c734a0f25f540c29dafd7a8afec0a29d"}, + {file = "outlines_core-0.1.14-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e06cb724770fd0fe1c8444382c4a6e79901bba33720f70fe6c8437f58eceb92e"}, + {file = "outlines_core-0.1.14-cp310-cp310-win32.whl", hash = "sha256:6d41da3d8a087fd54133cf910c2d5759da55490bbd0e3bc6c1e7907b54248415"}, + {file = "outlines_core-0.1.14-cp310-cp310-win_amd64.whl", hash = "sha256:646fd1073feed393bc77f9605a2fa27a54551ab04f85867ce789af1dee6326fa"}, + {file = "outlines_core-0.1.14-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:60f3a947fe09106f7668cf832c28b9269b8f0fc109f081608acfce9262213359"}, + {file = "outlines_core-0.1.14-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5e273a100c922f794d8e077a8161d0985d3005887066b4af3ae7afd3742fe9b8"}, + {file = "outlines_core-0.1.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:622e547f11a869fc67be40abc4cbcda89ae6f46f9eb46a1ec0666bd6807e0c67"}, + {file = "outlines_core-0.1.14-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:60c9933a9faaa51b39aea3518f1822b0d3ec2c9a13b16849caca3955e29e320d"}, + {file = "outlines_core-0.1.14-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4a8c616ce103ef9102dbf4326f67b03e1e0f46aa19351e57f4beb37588c00428"}, + {file = "outlines_core-0.1.14-cp311-cp311-win32.whl", hash = "sha256:1c77aaa4556cbb6e93cc42be0a6e262f175e0754b7694d702d642ff03df67f2c"}, + {file = "outlines_core-0.1.14-cp311-cp311-win_amd64.whl", hash = "sha256:eb6ffe410866f65dbe17e95b0aabd70d990f058a2dc4e8b74f9583b07248cd36"}, + {file = "outlines_core-0.1.14-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b0e408b033618f23e9bb928a47b33b1bd4c9d04a3dbec680a20977de3b4f590d"}, + {file = "outlines_core-0.1.14-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:21d1393a6da5d3320e8c8247e9deeb851c5c862fd6ea5c779bd29797e8987155"}, + {file = "outlines_core-0.1.14-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5829c568db76673d36caaf0f86e96748b491b4a209deb9be87617372394a5fb9"}, + {file = "outlines_core-0.1.14-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e855ec99bce1099c0755bcbfa44568adf7ae0083905ba04f58a17614ddf0fe7"}, + {file = "outlines_core-0.1.14-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b897cfbf9c2719aa011d9b439b4c6751d9c7df5683b2169617972d4b4a914403"}, + {file = "outlines_core-0.1.14-cp38-cp38-win32.whl", hash = "sha256:4c9d908004b31bcd432156d60f4895bf5e1b51ca8c8eed82b12f1bb57d5bf7fd"}, + {file = "outlines_core-0.1.14-cp38-cp38-win_amd64.whl", hash = "sha256:6668a930d928216d0b319ad84947903f1e27556f604a9743051f795b11008b64"}, + {file = "outlines_core-0.1.14-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b745aa469cf3fb347b79a257804d75d1324e01691158664c1e413a816ce6b98d"}, + {file = "outlines_core-0.1.14-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:27504c8360467429d6223ebc49180d6956d7418bfc3d324f6ad10f069e1813ad"}, + {file = "outlines_core-0.1.14-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd8f1e1d91a206a520d1c577ce00136de2beb1d200ef93759fd4c9f45abe24d3"}, + {file = "outlines_core-0.1.14-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f30c8acb42895b624c504b85678331c5f9376fa4b8069ce06a27cf80f5881e27"}, + {file = "outlines_core-0.1.14-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0e6cd0e7d995a7b04d90139a695279ab4a9eb7f492618b2c037a85eaf5f9fc59"}, + {file = "outlines_core-0.1.14-cp39-cp39-win32.whl", hash = "sha256:3104af4084da0e7c3d4b8538b43c725581d66bb68d426bc389680f06c3667476"}, + {file = "outlines_core-0.1.14-cp39-cp39-win_amd64.whl", hash = "sha256:45c6b9baded0337c4dcfa156af05ec4efd2b25c4d976e77be28146e4037b991f"}, + {file = "outlines_core-0.1.14.tar.gz", hash = "sha256:6db033e4f8e48381164e36cc716746640ad5022f0d86e4c88af15c75886b93a4"}, +] + +[package.dependencies] +interegular = "*" +jsonschema = "*" + +[package.extras] +test = ["accelerate", "asv", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "pillow", "pre-commit", "pydantic", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "setuptools-rust", "torch", "transformers"] [[package]] name = "packaging" @@ -2490,6 +2473,17 @@ numpy = ">=1.16.6" [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pycountry" +version = "24.6.1" +description = "ISO country, subdivision, language, currency and script definitions and their translations" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pycountry-24.6.1-py3-none-any.whl", hash = "sha256:f1a4fb391cd7214f8eefd39556d740adcc233c778a27f8942c8dca351d6ce06f"}, + {file = "pycountry-24.6.1.tar.gz", hash = "sha256:b61b3faccea67f87d10c1f2b0fc0be714409e8fcdcc1315613174f6466c10221"}, +] + [[package]] name = "pydantic" version = "2.9.2" diff --git a/server/pyproject.toml b/server/pyproject.toml index 91ddfd6c..ca65b8c8 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -34,7 +34,7 @@ peft = { version = "^0.10", optional = true } torch = { version = "^2.4.0", optional = true } scipy = "^1.11.1" pillow = "^10.0.0" -outlines= { version = "^0.0.34", optional = true } +outlines= { version = "^0.1.1", optional = true } prometheus-client = "^0.20.0" py-cpuinfo = "^9.0.0" compressed-tensors = { version = "^0.7.1", optional = true } diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py index 9abd886f..ec2813a1 100644 --- a/server/text_generation_server/utils/logits_process.py +++ b/server/text_generation_server/utils/logits_process.py @@ -5,7 +5,7 @@ from loguru import logger from typing import Dict, Union from text_generation_server.pb.generate_pb2 import GrammarType -from outlines.fsm.fsm import RegexFSM +from outlines.fsm.guide import RegexGuide from outlines.fsm.json_schema import build_regex_from_schema from functools import lru_cache from typing import List, Optional, DefaultDict @@ -482,7 +482,7 @@ class HeterogeneousProcessorWrapper(LogitsProcessor): class GrammarLogitProcessor(LogitsProcessor): fsm_state: DefaultDict[int, int] - fsm: RegexFSM + fsm: RegexGuide def __init__(self, tokenizer, device, grammar, grammar_type): self.device = device @@ -498,9 +498,10 @@ class GrammarLogitProcessor(LogitsProcessor): ): if fsm_grammar_state == -1 or self.fsm is None: return logits - allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) + allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens mask = torch.full_like(logits, -math.inf) - mask[:, allowed_tokens] = 0 + if allowed_tokens is not None: + mask[:, allowed_tokens] = 0 biased_scores = logits + mask return biased_scores @@ -513,7 +514,7 @@ class GrammarLogitProcessor(LogitsProcessor): def _advance(next_token_id, fsm_grammar_state, fsm): if fsm_grammar_state == -1: return fsm_grammar_state - return fsm.next_state(fsm_grammar_state, next_token_id) + return fsm.get_next_state(fsm_grammar_state, next_token_id) # TODO: move grammar compilation into the router @staticmethod @@ -530,7 +531,7 @@ class GrammarLogitProcessor(LogitsProcessor): schema = "(.*?)" elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: pass # schema is already a regex just here for clarity - fsm = RegexFSM(schema, tokenizer) + fsm = RegexGuide.from_regex(schema, tokenizer) logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s") return fsm @@ -588,8 +589,9 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): fsm = self.fsms[i] if fsm_grammar_states[i] == -1 or fsm is None: continue - allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) - mask[i, allowed_tokens] = 0 + allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens + if allowed_tokens is not None: + mask[i, allowed_tokens] = 0 logits[i] += mask[i] return logits