This commit is contained in:
OlivierDehaene 2024-02-16 17:50:57 +01:00 committed by GitHub
parent 0f2daad8b9
commit 4139054b82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1304 additions and 765 deletions

307
Cargo.lock generated
View File

@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "ahash"
version = "0.8.7"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01"
checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff"
dependencies = [
"cfg-if",
"once_cell",
@ -54,9 +54,9 @@ dependencies = [
[[package]]
name = "anstyle"
version = "1.0.4"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87"
checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc"
[[package]]
name = "anstyle-parse"
@ -128,7 +128,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -139,7 +139,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -288,9 +288,9 @@ dependencies = [
[[package]]
name = "bumpalo"
version = "3.14.0"
version = "3.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec"
checksum = "d32a994c2b3ca201d9b263612a374263f05e7adde37c4707f693dcd375076d1f"
[[package]]
name = "bytecount"
@ -321,9 +321,9 @@ dependencies = [
[[package]]
name = "cargo-platform"
version = "0.1.6"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ceed8ef69d8518a5dda55c07425450b58a4e1946f4951eab6d7191ee86c2443d"
checksum = "694c8807f2ae16faecc43dc17d74b3eb042482789fd0eb64b39a2e04e087053f"
dependencies = [
"serde",
]
@ -365,9 +365,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "clap"
version = "4.4.18"
version = "4.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c"
checksum = "80c21025abd42669a92efc996ef13cfb2c5c627858421ea58d5c3b331a6c134f"
dependencies = [
"clap_builder",
"clap_derive",
@ -375,33 +375,33 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.4.18"
version = "4.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7"
checksum = "458bf1f341769dfcf849846f65dffdf9146daa56bcd2a47cb4e1de9915567c99"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"strsim",
"strsim 0.11.0",
]
[[package]]
name = "clap_derive"
version = "4.4.7"
version = "4.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442"
checksum = "307bc0538d5f0f83b8248db3087aa92fe504e4691294d0c96c0eabc33f47ba47"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
name = "clap_lex"
version = "0.6.0"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1"
checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce"
[[package]]
name = "colorchoice"
@ -449,9 +449,9 @@ dependencies = [
[[package]]
name = "crc32fast"
version = "1.3.2"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d"
checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa"
dependencies = [
"cfg-if",
]
@ -555,7 +555,7 @@ dependencies = [
"ident_case",
"proc-macro2",
"quote",
"strsim",
"strsim 0.10.0",
"syn 1.0.109",
]
@ -672,9 +672,9 @@ dependencies = [
[[package]]
name = "either"
version = "1.9.0"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a"
[[package]]
name = "encode_unicode"
@ -836,7 +836,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -927,7 +927,7 @@ dependencies = [
"futures-sink",
"futures-util",
"http",
"indexmap 2.1.0",
"indexmap 2.2.3",
"slab",
"tokio",
"tokio-util",
@ -963,9 +963,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "hermit-abi"
version = "0.3.4"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d3d0e0f38255e7fa3cf31335b3a56f05febd18025f4db5ef7a0cfb4f8da651f"
checksum = "bd5256b483761cd23699d0da46cc6fd2ee3be420bbe6d020ae4a091e70b7e9fd"
[[package]]
name = "hf-hub"
@ -1125,9 +1125,9 @@ dependencies = [
[[package]]
name = "indexmap"
version = "2.1.0"
version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f"
checksum = "233cf39063f058ea2caae4091bf4a3ef70a653afbc026f5c4a4135d114e3c177"
dependencies = [
"equivalent",
"hashbrown 0.14.3",
@ -1136,9 +1136,9 @@ dependencies = [
[[package]]
name = "indicatif"
version = "0.17.7"
version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb28741c9db9a713d93deb3bb9515c20788cef5815265bee4980e87bde7e0f25"
checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3"
dependencies = [
"console",
"instant",
@ -1199,6 +1199,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.10"
@ -1207,9 +1216,9 @@ checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
[[package]]
name = "js-sys"
version = "0.3.67"
version = "0.3.68"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a1d36f1235bc969acba30b7f5990b864423a6068a10f7c90ae8f0112e3a59d1"
checksum = "406cda4b368d531c842222cf9d2600a9a4acce8d29423695379c6868a143a9ee"
dependencies = [
"wasm-bindgen",
]
@ -1222,9 +1231,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.152"
version = "0.2.153"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7"
checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
[[package]]
name = "libm"
@ -1354,7 +1363,7 @@ checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -1405,9 +1414,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.7.1"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7"
checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7"
dependencies = [
"adler",
]
@ -1442,7 +1451,7 @@ checksum = "f686d68a09079e63b1d2c64aa305095887ce50565f00a922ebfaeeee0d9ba6ce"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -1566,10 +1575,16 @@ dependencies = [
]
[[package]]
name = "num-traits"
version = "0.2.17"
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-traits"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a"
dependencies = [
"autocfg",
"libm",
@ -1587,9 +1602,9 @@ dependencies = [
[[package]]
name = "num_threads"
version = "0.1.6"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44"
checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9"
dependencies = [
"libc",
]
@ -1660,7 +1675,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -1856,7 +1871,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9"
dependencies = [
"fixedbitset",
"indexmap 2.1.0",
"indexmap 2.2.3",
]
[[package]]
@ -1876,7 +1891,7 @@ checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -1893,9 +1908,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "pkg-config"
version = "0.3.29"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb"
checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec"
[[package]]
name = "portable-atomic"
@ -1922,7 +1937,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5"
dependencies = [
"proc-macro2",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -1995,7 +2010,7 @@ dependencies = [
"prost 0.12.3",
"prost-types",
"regex",
"syn 2.0.48",
"syn 2.0.49",
"tempfile",
"which",
]
@ -2023,7 +2038,7 @@ dependencies = [
"itertools 0.11.0",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -2219,9 +2234,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
[[package]]
name = "reqwest"
version = "0.11.23"
version = "0.11.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41"
checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251"
dependencies = [
"base64 0.21.7",
"bytes",
@ -2241,9 +2256,11 @@ dependencies = [
"once_cell",
"percent-encoding",
"pin-project-lite",
"rustls-pemfile",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"system-configuration",
"tokio",
"tokio-native-tls",
@ -2305,7 +2322,7 @@ dependencies = [
"quote",
"rust-embed-utils",
"shellexpand",
"syn 2.0.48",
"syn 2.0.49",
"walkdir",
]
@ -2336,9 +2353,9 @@ dependencies = [
[[package]]
name = "rustix"
version = "0.38.30"
version = "0.38.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca"
checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949"
dependencies = [
"bitflags 2.4.2",
"errno",
@ -2361,14 +2378,16 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.21.10"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba"
checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41"
dependencies = [
"log",
"ring 0.17.7",
"rustls-pki-types",
"rustls-webpki",
"sct",
"subtle",
"zeroize",
]
[[package]]
@ -2381,12 +2400,19 @@ dependencies = [
]
[[package]]
name = "rustls-webpki"
version = "0.101.7"
name = "rustls-pki-types"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
checksum = "048a63e5b3ac996d78d402940b5fa47973d2d080c6c6fffa1d0f19c4445310b7"
[[package]]
name = "rustls-webpki"
version = "0.102.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610"
dependencies = [
"ring 0.17.7",
"rustls-pki-types",
"untrusted 0.9.0",
]
@ -2470,29 +2496,29 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.195"
version = "1.0.196"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02"
checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.195"
version = "1.0.196"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c"
checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
name = "serde_json"
version = "1.0.111"
version = "1.0.113"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "176e46fa42316f18edd598015a5166857fc835ec732f5215eac6b7bdbf0a84f4"
checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79"
dependencies = [
"itoa",
"ryu",
@ -2582,9 +2608,9 @@ dependencies = [
[[package]]
name = "sketches-ddsketch"
version = "0.2.1"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68a406c1882ed7f29cd5e248c9848a80e7cb6ae0fea82346d2746f2f941c07e1"
checksum = "85636c14b73d81f541e525f585c0a2109e6744e1565b5c1668e31c70c10ed65c"
[[package]]
name = "slab"
@ -2650,6 +2676,12 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "strsim"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01"
[[package]]
name = "strum"
version = "0.25.0"
@ -2669,9 +2701,15 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
name = "subtle"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc"
[[package]]
name = "syn"
version = "1.0.109"
@ -2685,9 +2723,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.48"
version = "2.0.49"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f"
checksum = "915aea9e586f80826ee59f8453c1101f9d1c4b3964cd2460185ee8e299ada496"
dependencies = [
"proc-macro2",
"quote",
@ -2761,20 +2799,19 @@ dependencies = [
[[package]]
name = "tempfile"
version = "3.9.0"
version = "3.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01ce4141aa927a6d1bd34a041795abd0db1cccba5d5f24b009f694bdf3a1f3fa"
checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67"
dependencies = [
"cfg-if",
"fastrand",
"redox_syscall",
"rustix",
"windows-sys 0.52.0",
]
[[package]]
name = "text-generation-benchmark"
version = "1.4.0"
version = "1.4.1"
dependencies = [
"average",
"clap",
@ -2795,7 +2832,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "1.4.0"
version = "1.4.1"
dependencies = [
"futures",
"grpc-metadata",
@ -2811,7 +2848,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "1.4.0"
version = "1.4.1"
dependencies = [
"clap",
"ctrlc",
@ -2827,7 +2864,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "1.4.0"
version = "1.4.1"
dependencies = [
"async-stream",
"axum",
@ -2850,7 +2887,7 @@ dependencies = [
"serde_json",
"text-generation-client",
"thiserror",
"tokenizers 0.15.1",
"tokenizers 0.15.2",
"tokio",
"tokio-stream",
"tower-http",
@ -2864,22 +2901,22 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.56"
version = "1.0.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad"
checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.56"
version = "1.0.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471"
checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -2894,13 +2931,14 @@ dependencies = [
[[package]]
name = "time"
version = "0.3.31"
version = "0.3.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e"
checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749"
dependencies = [
"deranged",
"itoa",
"libc",
"num-conv",
"num_threads",
"powerfmt",
"serde",
@ -2916,10 +2954,11 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
[[package]]
name = "time-macros"
version = "0.2.16"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f"
checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774"
dependencies = [
"num-conv",
"time-core",
]
@ -2974,9 +3013,9 @@ dependencies = [
[[package]]
name = "tokenizers"
version = "0.15.1"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812"
checksum = "3dd47962b0ba36e7fd33518fbf1754d136fd1474000162bbf2a8b5fcb2d3654d"
dependencies = [
"aho-corasick",
"clap",
@ -2985,7 +3024,7 @@ dependencies = [
"getrandom",
"hf-hub",
"indicatif",
"itertools 0.11.0",
"itertools 0.12.1",
"lazy_static",
"log",
"macro_rules_attribute",
@ -2996,7 +3035,7 @@ dependencies = [
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.7.5",
"regex-syntax 0.8.2",
"serde",
"serde_json",
"spm_precompiled",
@ -3008,9 +3047,9 @@ dependencies = [
[[package]]
name = "tokio"
version = "1.35.1"
version = "1.36.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104"
checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931"
dependencies = [
"backtrace",
"bytes",
@ -3043,7 +3082,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -3158,7 +3197,7 @@ dependencies = [
"proc-macro2",
"prost-build",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -3231,7 +3270,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -3379,9 +3418,9 @@ dependencies = [
[[package]]
name = "unicode-segmentation"
version = "1.10.1"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36"
checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202"
[[package]]
name = "unicode-width"
@ -3409,16 +3448,17 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "ureq"
version = "2.9.1"
version = "2.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97"
checksum = "11f214ce18d8b2cbe84ed3aa6486ed3f5b285cf8d8fbdbce9f3f767a724adc35"
dependencies = [
"base64 0.21.7",
"flate2",
"log",
"native-tls",
"once_cell",
"rustls 0.21.10",
"rustls 0.22.2",
"rustls-pki-types",
"rustls-webpki",
"serde",
"serde_json",
@ -3455,7 +3495,7 @@ version = "3.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d82b1bc5417102a73e8464c686eef947bdfb99fcdfc0a4f228e81afa9526470a"
dependencies = [
"indexmap 2.1.0",
"indexmap 2.2.3",
"serde",
"serde_json",
"utoipa-gen",
@ -3471,7 +3511,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
@ -3551,9 +3591,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasm-bindgen"
version = "0.2.90"
version = "0.2.91"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406"
checksum = "c1e124130aee3fb58c5bdd6b639a0509486b0338acaaae0c84a5124b0f588b7f"
dependencies = [
"cfg-if",
"wasm-bindgen-macro",
@ -3561,24 +3601,24 @@ dependencies = [
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.90"
version = "0.2.91"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd"
checksum = "c9e7e1900c352b609c8488ad12639a311045f40a35491fb69ba8c12f758af70b"
dependencies = [
"bumpalo",
"log",
"once_cell",
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.40"
version = "0.4.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bde2032aeb86bdfaecc8b261eef3cba735cc426c1f3a3416d1e0791be95fc461"
checksum = "877b9c3f61ceea0e56331985743b13f3d25c406a7098d45180fb5f09bc19ed97"
dependencies = [
"cfg-if",
"js-sys",
@ -3588,9 +3628,9 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.90"
version = "0.2.91"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999"
checksum = "b30af9e2d358182b5c7449424f017eba305ed32a7010509ede96cdc4696c46ed"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
@ -3598,28 +3638,28 @@ dependencies = [
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.90"
version = "0.2.91"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7"
checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.90"
version = "0.2.91"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b"
checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838"
[[package]]
name = "web-sys"
version = "0.3.67"
version = "0.3.68"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58cd2333b6e0be7a39605f0e255892fd7418a682d8da8fe042fe25128794d2ed"
checksum = "96565907687f7aceb35bc5fc03770a8a0471d82e479f25832f54a0e3f4b28446"
dependencies = [
"js-sys",
"wasm-bindgen",
@ -3637,9 +3677,12 @@ dependencies = [
[[package]]
name = "webpki-roots"
version = "0.25.3"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10"
checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
@ -3928,9 +3971,15 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"syn 2.0.49",
]
[[package]]
name = "zeroize"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
[[package]]
name = "zip"
version = "0.6.6"

View File

@ -9,7 +9,7 @@ members = [
resolver = "2"
[workspace.package]
version = "1.4.0"
version = "1.4.1"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -225,7 +225,7 @@ COPY server/Makefile server/Makefile
RUN cd server && \
make gen-server && \
pip install -r requirements_cuda.txt && \
pip install ".[bnb, accelerate, quantize, peft]" --no-cache-dir
pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark

View File

@ -150,7 +150,7 @@ COPY server/Makefile server/Makefile
RUN cd server && \
make gen-server && \
pip install -r requirements_rocm.txt && \
pip install ".[accelerate, peft]" --no-cache-dir
pip install ".[accelerate, peft, outlines]" --no-cache-dir
# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "1.4.0"
"version": "1.4.1"
},
"paths": {
"/": {
@ -590,8 +590,11 @@
"minimum": 0
},
"logprobs": {
"type": "number",
"format": "float",
"allOf": [
{
"$ref": "#/components/schemas/ChatCompletionLogprobs"
}
],
"nullable": true
}
}
@ -710,7 +713,7 @@
"presence_penalty": {
"type": "number",
"format": "float",
"description": "UNUSED\nNumber between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\nincreasing the model's likelihood to talk about new topics",
"description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\nincreasing the model's likelihood to talk about new topics",
"example": 0.1,
"nullable": true
},
@ -734,7 +737,7 @@
"top_logprobs": {
"type": "integer",
"format": "int32",
"description": "UNUSED\nAn integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\nan associated log probability. logprobs must be set to true if this parameter is used.",
"description": "An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\nan associated log probability. logprobs must be set to true if this parameter is used.",
"example": "5",
"nullable": true,
"minimum": 0
@ -870,6 +873,22 @@
"default": "false",
"example": true
},
"frequency_penalty": {
"type": "number",
"format": "float",
"default": "null",
"example": 0.1,
"nullable": true,
"exclusiveMinimum": -2
},
"grammar": {
"allOf": [
{
"$ref": "#/components/schemas/GrammarType"
}
],
"nullable": true
},
"max_new_tokens": {
"type": "integer",
"format": "int32",
@ -1026,6 +1045,12 @@
"example": "null",
"nullable": true
},
"max_batch_size": {
"type": "integer",
"example": "null",
"nullable": true,
"minimum": 0
},
"max_batch_total_tokens": {
"type": "integer",
"format": "int32",
@ -1119,6 +1144,11 @@
"type": "string",
"example": "My name is David and I"
},
"name": {
"type": "string",
"example": "\"David\"",
"nullable": true
},
"role": {
"type": "string",
"example": "user"

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-integration-tests"
version = "1.4.0"
version = "1.4.1"
description = "Text Generation Inference integration tests"
authors = ["Nicolas Patry <nicolas@huggingface.co>"]

View File

@ -23,7 +23,7 @@ install-megablocks:
install: gen-server
pip install pip --upgrade
pip install -r requirements_cuda.txt
pip install -e ".[bnb, accelerate, quantize, peft]"
pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

1635
server/poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
version = "1.4.0"
version = "1.4.1"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@ -34,7 +34,7 @@ peft = { version = "^0.8.2", optional = true }
torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1"
pillow = "^10.0.0"
outlines="^0.0.27"
outlines= { version = "^0.0.27", optional = true }
[tool.poetry.extras]
torch = ["torch"]
@ -42,6 +42,7 @@ accelerate = ["accelerate"]
bnb = ["bitsandbytes"]
peft = ["peft"]
quantize = ["texttable", "datasets", "accelerate"]
outlines = ["outlines"]
[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.51.1"

View File

@ -1,6 +1,6 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
bitsandbytes==0.41.3.post2 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
@ -10,14 +10,14 @@ filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -29,19 +29,19 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,5 +1,5 @@
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
@ -9,14 +9,14 @@ filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@ -28,19 +28,19 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,10 +1,8 @@
import math
import torch
import json
from loguru import logger
from functools import lru_cache
from typing import Optional, List, Dict, Union
from typing import Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM
@ -492,7 +490,7 @@ class GrammarLogitProcessor(LogitsProcessor):
if fsm_grammar_state == -1 or self.fsm is None:
return logits
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
mask = torch.full_like(logits, -math.inf)
mask[allowed_tokens] = 0
biased_scores = logits + mask
return biased_scores
@ -550,22 +548,15 @@ class GrammarLogitProcessor(LogitsProcessor):
logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
return tokenizer
def filter(self, indices):
new_fsms = []
for i in indices:
new_fsms.append(self.fsms[i])
self.fsms = new_fsms
return self
class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
def __init__(self, tokenizer, device, grammars, grammar_type):
def __init__(self, tokenizer, device, grammars, grammar_types):
self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsms = []
for i in range(len(grammars)):
for grammar, grammar_type in zip(grammars, grammar_types):
fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type[i], grammars[i], self.tokenizer
grammar_type, grammar, self.tokenizer
)
self.fsms.append(fsm)
@ -573,7 +564,6 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
self,
logits: torch.Tensor,
fsm_grammar_states: List[int],
mask: torch.Tensor,
):
mask = torch.full_like(logits, -math.inf)
for i in range(logits.shape[0]):
@ -585,7 +575,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
logits += mask
return logits
def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
def advance_batch(self, next_token_ids, fsm_grammar_states):
return [
GrammarLogitProcessor._advance(
next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
@ -599,4 +589,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
)
def filter(self, indices):
return GrammarLogitProcessor.filter(self, indices)
new_fsms = []
for i in indices:
new_fsms.append(self.fsms[i])
self.fsms = new_fsms
return self

View File

@ -341,7 +341,7 @@ class HeterogeneousNextTokenChooser:
for warper in self.warpers:
_scores = warper(input_ids, _scores)
if self.grammar_processor is not None:
_scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask)
_scores = self.grammar_processor(_scores, self.fsm_grammar_states)
_next_ids = self.choice(_scores)
scores[:, j] = _scores
next_ids[:, j] = _next_ids
@ -402,7 +402,7 @@ class HeterogeneousNextTokenChooser:
def advance_grammar(self, next_ids: List[int]):
if self.grammar_processor is not None:
other_new_states = self.grammar_processor.advance_batch(
next_ids, self.fsm_grammar_states, self.grammars
next_ids, self.fsm_grammar_states
)
self.fsm_grammar_states = other_new_states
return self