Upgrade outlines to 0.1.1 (#2742)

* Upgrade outlines to 0.1.1

* Update for new API

* Check if allowed tokens is None

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Alex Weston 2024-11-15 07:22:52 -05:00 committed by GitHub
parent 003eaec0fb
commit 4580ced091
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 93 additions and 97 deletions

170
server/poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "accelerate"
@ -167,6 +167,17 @@ files = [
[package.dependencies]
frozenlist = ">=1.1.0"
[[package]]
name = "airportsdata"
version = "20241001"
description = "Extensive database of location and timezone data for nearly every airport and landing strip in the world."
optional = true
python-versions = ">=3.9"
files = [
{file = "airportsdata-20241001-py3-none-any.whl", hash = "sha256:67d71cf2c5378cc17ff66b62b1e11aa2444043949c894543ac8fd8dafce192fd"},
{file = "airportsdata-20241001.tar.gz", hash = "sha256:fa0bd143b4f4be3557cb892fa0612ef210fd91a92bd720b4d8221de576a4fa00"},
]
[[package]]
name = "annotated-types"
version = "0.7.0"
@ -1043,17 +1054,6 @@ MarkupSafe = ">=2.0"
[package.extras]
i18n = ["Babel (>=2.7)"]
[[package]]
name = "joblib"
version = "1.4.2"
description = "Lightweight pipelining with Python functions"
optional = true
python-versions = ">=3.8"
files = [
{file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"},
{file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"},
]
[[package]]
name = "jsonschema"
version = "4.23.0"
@ -1106,36 +1106,6 @@ interegular = ["interegular (>=0.3.1,<0.4.0)"]
nearley = ["js2py"]
regex = ["regex"]
[[package]]
name = "llvmlite"
version = "0.43.0"
description = "lightweight wrapper around basic LLVM functionality"
optional = true
python-versions = ">=3.9"
files = [
{file = "llvmlite-0.43.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a289af9a1687c6cf463478f0fa8e8aa3b6fb813317b0d70bf1ed0759eab6f761"},
{file = "llvmlite-0.43.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d4fd101f571a31acb1559ae1af30f30b1dc4b3186669f92ad780e17c81e91bc"},
{file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d434ec7e2ce3cc8f452d1cd9a28591745de022f931d67be688a737320dfcead"},
{file = "llvmlite-0.43.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6912a87782acdff6eb8bf01675ed01d60ca1f2551f8176a300a886f09e836a6a"},
{file = "llvmlite-0.43.0-cp310-cp310-win_amd64.whl", hash = "sha256:14f0e4bf2fd2d9a75a3534111e8ebeb08eda2f33e9bdd6dfa13282afacdde0ed"},
{file = "llvmlite-0.43.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8d0618cb9bfe40ac38a9633f2493d4d4e9fcc2f438d39a4e854f39cc0f5f98"},
{file = "llvmlite-0.43.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e0a9a1a39d4bf3517f2af9d23d479b4175ead205c592ceeb8b89af48a327ea57"},
{file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1da416ab53e4f7f3bc8d4eeba36d801cc1894b9fbfbf2022b29b6bad34a7df2"},
{file = "llvmlite-0.43.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977525a1e5f4059316b183fb4fd34fa858c9eade31f165427a3977c95e3ee749"},
{file = "llvmlite-0.43.0-cp311-cp311-win_amd64.whl", hash = "sha256:d5bd550001d26450bd90777736c69d68c487d17bf371438f975229b2b8241a91"},
{file = "llvmlite-0.43.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f99b600aa7f65235a5a05d0b9a9f31150c390f31261f2a0ba678e26823ec38f7"},
{file = "llvmlite-0.43.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:35d80d61d0cda2d767f72de99450766250560399edc309da16937b93d3b676e7"},
{file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eccce86bba940bae0d8d48ed925f21dbb813519169246e2ab292b5092aba121f"},
{file = "llvmlite-0.43.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df6509e1507ca0760787a199d19439cc887bfd82226f5af746d6977bd9f66844"},
{file = "llvmlite-0.43.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a2872ee80dcf6b5dbdc838763d26554c2a18aa833d31a2635bff16aafefb9c9"},
{file = "llvmlite-0.43.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9cd2a7376f7b3367019b664c21f0c61766219faa3b03731113ead75107f3b66c"},
{file = "llvmlite-0.43.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18e9953c748b105668487b7c81a3e97b046d8abf95c4ddc0cd3c94f4e4651ae8"},
{file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74937acd22dc11b33946b67dca7680e6d103d6e90eeaaaf932603bec6fe7b03a"},
{file = "llvmlite-0.43.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc9efc739cc6ed760f795806f67889923f7274276f0eb45092a1473e40d9b867"},
{file = "llvmlite-0.43.0-cp39-cp39-win_amd64.whl", hash = "sha256:47e147cdda9037f94b399bf03bfd8a6b6b1f2f90be94a454e3386f006455a9b4"},
{file = "llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5"},
]
[[package]]
name = "loguru"
version = "0.6.0"
@ -1577,40 +1547,6 @@ doc = ["nb2plots (>=0.7)", "nbconvert (<7.9)", "numpydoc (>=1.6)", "pillow (>=9.
extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.11)", "sympy (>=1.10)"]
test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
[[package]]
name = "numba"
version = "0.60.0"
description = "compiling Python code using LLVM"
optional = true
python-versions = ">=3.9"
files = [
{file = "numba-0.60.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d761de835cd38fb400d2c26bb103a2726f548dc30368853121d66201672e651"},
{file = "numba-0.60.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:159e618ef213fba758837f9837fb402bbe65326e60ba0633dbe6c7f274d42c1b"},
{file = "numba-0.60.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1527dc578b95c7c4ff248792ec33d097ba6bef9eda466c948b68dfc995c25781"},
{file = "numba-0.60.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe0b28abb8d70f8160798f4de9d486143200f34458d34c4a214114e445d7124e"},
{file = "numba-0.60.0-cp310-cp310-win_amd64.whl", hash = "sha256:19407ced081d7e2e4b8d8c36aa57b7452e0283871c296e12d798852bc7d7f198"},
{file = "numba-0.60.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a17b70fc9e380ee29c42717e8cc0bfaa5556c416d94f9aa96ba13acb41bdece8"},
{file = "numba-0.60.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb02b344a2a80efa6f677aa5c40cd5dd452e1b35f8d1c2af0dfd9ada9978e4b"},
{file = "numba-0.60.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f4fde652ea604ea3c86508a3fb31556a6157b2c76c8b51b1d45eb40c8598703"},
{file = "numba-0.60.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4142d7ac0210cc86432b818338a2bc368dc773a2f5cf1e32ff7c5b378bd63ee8"},
{file = "numba-0.60.0-cp311-cp311-win_amd64.whl", hash = "sha256:cac02c041e9b5bc8cf8f2034ff6f0dbafccd1ae9590dc146b3a02a45e53af4e2"},
{file = "numba-0.60.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d7da4098db31182fc5ffe4bc42c6f24cd7d1cb8a14b59fd755bfee32e34b8404"},
{file = "numba-0.60.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:38d6ea4c1f56417076ecf8fc327c831ae793282e0ff51080c5094cb726507b1c"},
{file = "numba-0.60.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:62908d29fb6a3229c242e981ca27e32a6e606cc253fc9e8faeb0e48760de241e"},
{file = "numba-0.60.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0ebaa91538e996f708f1ab30ef4d3ddc344b64b5227b67a57aa74f401bb68b9d"},
{file = "numba-0.60.0-cp312-cp312-win_amd64.whl", hash = "sha256:f75262e8fe7fa96db1dca93d53a194a38c46da28b112b8a4aca168f0df860347"},
{file = "numba-0.60.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:01ef4cd7d83abe087d644eaa3d95831b777aa21d441a23703d649e06b8e06b74"},
{file = "numba-0.60.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:819a3dfd4630d95fd574036f99e47212a1af41cbcb019bf8afac63ff56834449"},
{file = "numba-0.60.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b983bd6ad82fe868493012487f34eae8bf7dd94654951404114f23c3466d34b"},
{file = "numba-0.60.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c151748cd269ddeab66334bd754817ffc0cabd9433acb0f551697e5151917d25"},
{file = "numba-0.60.0-cp39-cp39-win_amd64.whl", hash = "sha256:3031547a015710140e8c87226b4cfe927cac199835e5bf7d4fe5cb64e814e3ab"},
{file = "numba-0.60.0.tar.gz", hash = "sha256:5df6158e5584eece5fc83294b949fd30b9f1125df7708862205217e068aabf16"},
]
[package.dependencies]
llvmlite = "==0.43.*"
numpy = ">=1.22,<2.1"
[[package]]
name = "numpy"
version = "1.26.4"
@ -1988,36 +1924,83 @@ opentelemetry-api = "1.25.0"
[[package]]
name = "outlines"
version = "0.0.34"
version = "0.1.1"
description = "Probabilistic Generative Model Programming"
optional = true
python-versions = ">=3.8"
python-versions = ">=3.9"
files = [
{file = "outlines-0.0.34-py3-none-any.whl", hash = "sha256:911588a7e64a4f193b97fb4c501d98ccfd4e95a98f6a3ada67a280bf0c373c50"},
{file = "outlines-0.0.34.tar.gz", hash = "sha256:594e7204c770b47a62eb5c2ba7d25ea0ab2e16882b5f04556712a0228d3d3309"},
{file = "outlines-0.1.1-py3-none-any.whl", hash = "sha256:896aee7f8f0472955104bb30fb118e525bced6885f09e833bb848782394f2c17"},
{file = "outlines-0.1.1.tar.gz", hash = "sha256:9c5d3524ef21343bd681757e8ed9a5b1fcb335ee68f9b6b0889062ce23b561fc"},
]
[package.dependencies]
airportsdata = "*"
cloudpickle = "*"
datasets = "*"
diskcache = "*"
interegular = "*"
jinja2 = "*"
joblib = "*"
jsonschema = "*"
lark = "*"
nest-asyncio = "*"
numba = "*"
numpy = "*"
numpy = "<2.0.0"
outlines-core = "0.1.14"
pycountry = "*"
pydantic = ">=2.0"
referencing = "*"
requests = "*"
scipy = "*"
torch = ">=2.1.0"
transformers = "*"
torch = "*"
tqdm = "*"
typing-extensions = "*"
[package.extras]
serve = ["fastapi", "pydantic (>=2.0)", "ray (==2.9.0)", "uvicorn", "vllm (>=0.3.0)"]
test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python (>=0.2.42)", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"]
serve = ["fastapi", "pydantic (>=2.0)", "uvicorn", "vllm (>=0.3.0)"]
test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "diff-cover", "exllamav2", "huggingface-hub", "llama-cpp-python", "mlx-lm", "openai (>=1.0.0)", "pillow", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers", "vllm"]
[[package]]
name = "outlines-core"
version = "0.1.14"
description = "Structured Text Generation in Rust"
optional = true
python-versions = ">=3.8"
files = [
{file = "outlines_core-0.1.14-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:291c6d9d348cb5562cd28ce44d80822d77238f1cd7c30d890b5b20488e71608d"},
{file = "outlines_core-0.1.14-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3a50e2f6745e0c34cc857d1bd5590e2966ad06e8ce10802976e9e6c116c7533d"},
{file = "outlines_core-0.1.14-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7dfe64b590a6a88dcc5e59f0a399fff0458cdcf97d68de07f08e1bd3bf8ac1d"},
{file = "outlines_core-0.1.14-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:100de068ce52893bec316481e65db8f1c734a0f25f540c29dafd7a8afec0a29d"},
{file = "outlines_core-0.1.14-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e06cb724770fd0fe1c8444382c4a6e79901bba33720f70fe6c8437f58eceb92e"},
{file = "outlines_core-0.1.14-cp310-cp310-win32.whl", hash = "sha256:6d41da3d8a087fd54133cf910c2d5759da55490bbd0e3bc6c1e7907b54248415"},
{file = "outlines_core-0.1.14-cp310-cp310-win_amd64.whl", hash = "sha256:646fd1073feed393bc77f9605a2fa27a54551ab04f85867ce789af1dee6326fa"},
{file = "outlines_core-0.1.14-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:60f3a947fe09106f7668cf832c28b9269b8f0fc109f081608acfce9262213359"},
{file = "outlines_core-0.1.14-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5e273a100c922f794d8e077a8161d0985d3005887066b4af3ae7afd3742fe9b8"},
{file = "outlines_core-0.1.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:622e547f11a869fc67be40abc4cbcda89ae6f46f9eb46a1ec0666bd6807e0c67"},
{file = "outlines_core-0.1.14-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:60c9933a9faaa51b39aea3518f1822b0d3ec2c9a13b16849caca3955e29e320d"},
{file = "outlines_core-0.1.14-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4a8c616ce103ef9102dbf4326f67b03e1e0f46aa19351e57f4beb37588c00428"},
{file = "outlines_core-0.1.14-cp311-cp311-win32.whl", hash = "sha256:1c77aaa4556cbb6e93cc42be0a6e262f175e0754b7694d702d642ff03df67f2c"},
{file = "outlines_core-0.1.14-cp311-cp311-win_amd64.whl", hash = "sha256:eb6ffe410866f65dbe17e95b0aabd70d990f058a2dc4e8b74f9583b07248cd36"},
{file = "outlines_core-0.1.14-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b0e408b033618f23e9bb928a47b33b1bd4c9d04a3dbec680a20977de3b4f590d"},
{file = "outlines_core-0.1.14-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:21d1393a6da5d3320e8c8247e9deeb851c5c862fd6ea5c779bd29797e8987155"},
{file = "outlines_core-0.1.14-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5829c568db76673d36caaf0f86e96748b491b4a209deb9be87617372394a5fb9"},
{file = "outlines_core-0.1.14-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7e855ec99bce1099c0755bcbfa44568adf7ae0083905ba04f58a17614ddf0fe7"},
{file = "outlines_core-0.1.14-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b897cfbf9c2719aa011d9b439b4c6751d9c7df5683b2169617972d4b4a914403"},
{file = "outlines_core-0.1.14-cp38-cp38-win32.whl", hash = "sha256:4c9d908004b31bcd432156d60f4895bf5e1b51ca8c8eed82b12f1bb57d5bf7fd"},
{file = "outlines_core-0.1.14-cp38-cp38-win_amd64.whl", hash = "sha256:6668a930d928216d0b319ad84947903f1e27556f604a9743051f795b11008b64"},
{file = "outlines_core-0.1.14-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b745aa469cf3fb347b79a257804d75d1324e01691158664c1e413a816ce6b98d"},
{file = "outlines_core-0.1.14-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:27504c8360467429d6223ebc49180d6956d7418bfc3d324f6ad10f069e1813ad"},
{file = "outlines_core-0.1.14-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd8f1e1d91a206a520d1c577ce00136de2beb1d200ef93759fd4c9f45abe24d3"},
{file = "outlines_core-0.1.14-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f30c8acb42895b624c504b85678331c5f9376fa4b8069ce06a27cf80f5881e27"},
{file = "outlines_core-0.1.14-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0e6cd0e7d995a7b04d90139a695279ab4a9eb7f492618b2c037a85eaf5f9fc59"},
{file = "outlines_core-0.1.14-cp39-cp39-win32.whl", hash = "sha256:3104af4084da0e7c3d4b8538b43c725581d66bb68d426bc389680f06c3667476"},
{file = "outlines_core-0.1.14-cp39-cp39-win_amd64.whl", hash = "sha256:45c6b9baded0337c4dcfa156af05ec4efd2b25c4d976e77be28146e4037b991f"},
{file = "outlines_core-0.1.14.tar.gz", hash = "sha256:6db033e4f8e48381164e36cc716746640ad5022f0d86e4c88af15c75886b93a4"},
]
[package.dependencies]
interegular = "*"
jsonschema = "*"
[package.extras]
test = ["accelerate", "asv", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "pillow", "pre-commit", "pydantic", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "setuptools-rust", "torch", "transformers"]
[[package]]
name = "packaging"
@ -2490,6 +2473,17 @@ numpy = ">=1.16.6"
[package.extras]
test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"]
[[package]]
name = "pycountry"
version = "24.6.1"
description = "ISO country, subdivision, language, currency and script definitions and their translations"
optional = true
python-versions = ">=3.8"
files = [
{file = "pycountry-24.6.1-py3-none-any.whl", hash = "sha256:f1a4fb391cd7214f8eefd39556d740adcc233c778a27f8942c8dca351d6ce06f"},
{file = "pycountry-24.6.1.tar.gz", hash = "sha256:b61b3faccea67f87d10c1f2b0fc0be714409e8fcdcc1315613174f6466c10221"},
]
[[package]]
name = "pydantic"
version = "2.9.2"

View File

@ -34,7 +34,7 @@ peft = { version = "^0.10", optional = true }
torch = { version = "^2.4.0", optional = true }
scipy = "^1.11.1"
pillow = "^10.0.0"
outlines= { version = "^0.0.34", optional = true }
outlines= { version = "^0.1.1", optional = true }
prometheus-client = "^0.20.0"
py-cpuinfo = "^9.0.0"
compressed-tensors = { version = "^0.7.1", optional = true }

View File

@ -5,7 +5,7 @@ from loguru import logger
from typing import Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.guide import RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from functools import lru_cache
from typing import List, Optional, DefaultDict
@ -482,7 +482,7 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
class GrammarLogitProcessor(LogitsProcessor):
fsm_state: DefaultDict[int, int]
fsm: RegexFSM
fsm: RegexGuide
def __init__(self, tokenizer, device, grammar, grammar_type):
self.device = device
@ -498,9 +498,10 @@ class GrammarLogitProcessor(LogitsProcessor):
):
if fsm_grammar_state == -1 or self.fsm is None:
return logits
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens
mask = torch.full_like(logits, -math.inf)
mask[:, allowed_tokens] = 0
if allowed_tokens is not None:
mask[:, allowed_tokens] = 0
biased_scores = logits + mask
return biased_scores
@ -513,7 +514,7 @@ class GrammarLogitProcessor(LogitsProcessor):
def _advance(next_token_id, fsm_grammar_state, fsm):
if fsm_grammar_state == -1:
return fsm_grammar_state
return fsm.next_state(fsm_grammar_state, next_token_id)
return fsm.get_next_state(fsm_grammar_state, next_token_id)
# TODO: move grammar compilation into the router
@staticmethod
@ -530,7 +531,7 @@ class GrammarLogitProcessor(LogitsProcessor):
schema = "(.*?)"
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
pass # schema is already a regex just here for clarity
fsm = RegexFSM(schema, tokenizer)
fsm = RegexGuide.from_regex(schema, tokenizer)
logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
return fsm
@ -588,8 +589,9 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
fsm = self.fsms[i]
if fsm_grammar_states[i] == -1 or fsm is None:
continue
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
mask[i, allowed_tokens] = 0
allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens
if allowed_tokens is not None:
mask[i, allowed_tokens] = 0
logits[i] += mask[i]
return logits