feat(router): make router input validation optional (#164)
parent 7dec65a244
commit 9987960062

Cargo.lock
@@ -8,6 +8,18 @@ version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
 
+[[package]]
+name = "aes"
+version = "0.7.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e8b47f52ea9bae42228d07ec09eb676433d7c4ed1ebdf0f1d1c29ed446f1ab8"
+dependencies = [
+ "cfg-if",
+ "cipher",
+ "cpufeatures",
+ "opaque-debug",
+]
+
 [[package]]
 name = "aho-corasick"
 version = "0.7.20"
@@ -17,15 +29,6 @@ dependencies = [
  "memchr",
 ]
 
-[[package]]
-name = "ansi_term"
-version = "0.12.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2"
-dependencies = [
- "winapi",
-]
-
 [[package]]
 name = "anstream"
 version = "0.2.6"
@@ -105,17 +108,6 @@ dependencies = [
  "syn 2.0.11",
 ]
 
-[[package]]
-name = "atty"
-version = "0.2.14"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
-dependencies = [
- "hermit-abi 0.1.19",
- "libc",
- "winapi",
-]
-
 [[package]]
 name = "autocfg"
 version = "1.1.0"
@@ -190,6 +182,12 @@ version = "0.21.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a"
 
+[[package]]
+name = "base64ct"
+version = "1.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
+
 [[package]]
 name = "bitflags"
 version = "1.3.2"
@@ -246,9 +244,9 @@ dependencies = [
 
 [[package]]
 name = "cached-path"
-version = "0.5.3"
+version = "0.6.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5f1c56d30236522ab3393a08746b138d4e16372001f42d29c88d513aeb8ab7ef"
+checksum = "097968e38f1319207f057d0f4d76452e4f4f847a5de61c5215379f297fa034f3"
 dependencies = [
  "flate2",
  "fs2",
@@ -264,7 +262,6 @@ dependencies = [
  "tempfile",
  "thiserror",
  "zip",
- "zip-extensions",
 ]
 
 [[package]]
@@ -278,6 +275,9 @@ name = "cc"
 version = "1.0.79"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
+dependencies = [
+ "jobserver",
+]
 
 [[package]]
 name = "cfg-if"
@@ -286,18 +286,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 
 [[package]]
-name = "clap"
-version = "2.34.0"
+name = "cipher"
+version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
+checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7"
 dependencies = [
- "ansi_term",
- "atty",
- "bitflags",
- "strsim 0.8.0",
- "textwrap",
- "unicode-width",
- "vec_map",
+ "generic-array",
 ]
 
 [[package]]
@@ -321,7 +315,7 @@ dependencies = [
  "anstyle",
  "bitflags",
  "clap_lex",
- "strsim 0.10.0",
+ "strsim",
 ]
 
 [[package]]
@@ -370,6 +364,12 @@ dependencies = [
  "windows-sys 0.42.0",
 ]
 
+[[package]]
+name = "constant_time_eq"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
+
 [[package]]
 name = "core-foundation"
 version = "0.9.3"
@@ -484,9 +484,9 @@ dependencies = [
 
 [[package]]
 name = "darling"
-version = "0.10.2"
+version = "0.14.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0d706e75d87e35569db781a9b5e2416cff1236a47ed380831f959382ccd5f858"
+checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
 dependencies = [
  "darling_core",
  "darling_macro",
@@ -494,23 +494,23 @@ dependencies = [
 
 [[package]]
 name = "darling_core"
-version = "0.10.2"
+version = "0.14.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f0c960ae2da4de88a91b2d920c2a7233b400bc33cb28453a2987822d8392519b"
+checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
 dependencies = [
  "fnv",
  "ident_case",
  "proc-macro2",
  "quote",
- "strsim 0.9.3",
+ "strsim",
  "syn 1.0.109",
 ]
 
 [[package]]
 name = "darling_macro"
-version = "0.10.2"
+version = "0.14.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72"
+checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
 dependencies = [
  "darling_core",
  "quote",
@@ -532,26 +532,32 @@ dependencies = [
 
 [[package]]
 name = "derive_builder"
-version = "0.9.0"
+version = "0.12.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a2658621297f2cf68762a6f7dc0bb7e1ff2cfd6583daef8ee0fed6f7ec468ec0"
+checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8"
+dependencies = [
+ "derive_builder_macro",
+]
+
+[[package]]
+name = "derive_builder_core"
+version = "0.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f"
 dependencies = [
  "darling",
- "derive_builder_core",
  "proc-macro2",
  "quote",
  "syn 1.0.109",
 ]
 
 [[package]]
-name = "derive_builder_core"
-version = "0.9.0"
+name = "derive_builder_macro"
+version = "0.12.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2791ea3e372c8495c0bc2033991d76b512cd799d07491fbd6890124db9458bef"
+checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
 dependencies = [
- "darling",
- "proc-macro2",
- "quote",
+ "derive_builder_core",
  "syn 1.0.109",
 ]
 
@@ -563,13 +569,14 @@ checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f"
 dependencies = [
  "block-buffer",
  "crypto-common",
+ "subtle",
 ]
 
 [[package]]
 name = "dirs"
-version = "3.0.2"
+version = "4.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309"
+checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059"
 dependencies = [
  "dirs-sys",
 ]
@@ -835,7 +842,7 @@ checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
 dependencies = [
  "cfg-if",
  "libc",
- "wasi 0.11.0+wasi-snapshot-preview1",
+ "wasi",
 ]
 
 [[package]]
@@ -885,15 +892,6 @@ version = "0.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
 
-[[package]]
-name = "hermit-abi"
-version = "0.1.19"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
-dependencies = [
- "libc",
-]
-
 [[package]]
 name = "hermit-abi"
 version = "0.2.6"
@@ -909,6 +907,15 @@ version = "0.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286"
 
+[[package]]
+name = "hmac"
+version = "0.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
+dependencies = [
+ "digest",
+]
+
 [[package]]
 name = "http"
 version = "0.2.9"
@@ -1113,6 +1120,15 @@ version = "1.0.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
 
+[[package]]
+name = "jobserver"
+version = "0.1.26"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "936cfd212a0155903bcbc060e316fb6cc7cbf2e1907329391ebadc1fe0ce77c2"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "js-sys"
 version = "0.3.61"
@@ -1240,10 +1256,31 @@ checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9"
 dependencies = [
  "libc",
  "log",
- "wasi 0.11.0+wasi-snapshot-preview1",
+ "wasi",
  "windows-sys 0.45.0",
 ]
 
+[[package]]
+name = "monostate"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a"
+dependencies = [
+ "monostate-impl",
+ "serde",
+]
+
+[[package]]
+name = "monostate-impl"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.11",
+]
+
 [[package]]
 name = "multimap"
 version = "0.8.3"
@@ -1348,6 +1385,12 @@ dependencies = [
  "pkg-config",
 ]
 
+[[package]]
+name = "opaque-debug"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
+
 [[package]]
 name = "openssl"
 version = "0.10.48"
@@ -1468,12 +1511,35 @@ dependencies = [
  "windows-sys 0.45.0",
 ]
 
+[[package]]
+name = "password-hash"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700"
+dependencies = [
+ "base64ct",
+ "rand_core",
+ "subtle",
+]
+
 [[package]]
 name = "paste"
 version = "1.0.12"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79"
 
+[[package]]
+name = "pbkdf2"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917"
+dependencies = [
+ "digest",
+ "hmac",
+ "password-hash",
+ "sha2",
+]
+
 [[package]]
 name = "percent-encoding"
 version = "2.2.0"
@@ -1891,6 +1957,17 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "sha1"
+version = "0.10.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
+dependencies = [
+ "cfg-if",
+ "cpufeatures",
+ "digest",
+]
+
 [[package]]
 name = "sha2"
 version = "0.10.6"
@@ -1978,24 +2055,18 @@ dependencies = [
  "unicode-segmentation",
 ]
 
-[[package]]
-name = "strsim"
-version = "0.8.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
-
-[[package]]
-name = "strsim"
-version = "0.9.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
-
 [[package]]
 name = "strsim"
 version = "0.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
 
+[[package]]
+name = "subtle"
+version = "2.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601"
+
 [[package]]
 name = "syn"
 version = "1.0.109"
@@ -2053,7 +2124,7 @@ name = "text-generation-benchmark"
 version = "0.1.0"
 dependencies = [
  "average",
- "clap 4.2.1",
+ "clap",
  "crossterm",
  "float-ord",
  "ratatui",
@@ -2084,15 +2155,6 @@ dependencies = [
  "tracing-error",
 ]
 
-[[package]]
-name = "textwrap"
-version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
-dependencies = [
- "unicode-width",
-]
-
 [[package]]
 name = "thiserror"
 version = "1.0.40"
@@ -2125,15 +2187,20 @@ dependencies = [
 
 [[package]]
 name = "time"
-version = "0.1.45"
+version = "0.3.20"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a"
+checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890"
 dependencies = [
- "libc",
- "wasi 0.10.0+wasi-snapshot-preview1",
- "winapi",
+ "serde",
+ "time-core",
 ]
 
+[[package]]
+name = "time-core"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"
+
 [[package]]
 name = "tinyvec"
 version = "1.6.0"
@@ -2151,13 +2218,13 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
 
 [[package]]
 name = "tokenizers"
-version = "0.13.2"
+version = "0.13.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f4ff2dd291eac98dcea13e8cf7a0b28c373a90dc9210ccdab0fa9e69ee0cac69"
+checksum = "5cf49017523bf0bc01c9966f172c5f120bbb7b96cccd1708772dd42e767fb9f5"
 dependencies = [
  "aho-corasick",
  "cached-path",
- "clap 2.34.0",
+ "clap",
  "derive_builder",
  "dirs",
  "esaxx-rs",
@@ -2167,6 +2234,7 @@ dependencies = [
  "lazy_static",
  "log",
  "macro_rules_attribute",
+ "monostate",
  "onig",
  "paste",
  "rand",
@@ -2535,12 +2603,6 @@ version = "0.2.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
 
-[[package]]
-name = "vec_map"
-version = "0.8.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
-
 [[package]]
 name = "version_check"
 version = "0.9.4"
@@ -2557,12 +2619,6 @@ dependencies = [
  "try-lock",
 ]
 
-[[package]]
-name = "wasi"
-version = "0.10.0+wasi-snapshot-preview1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f"
-
 [[package]]
 name = "wasi"
 version = "0.11.0+wasi-snapshot-preview1"
@@ -2779,23 +2835,50 @@ dependencies = [
 
 [[package]]
 name = "zip"
-version = "0.5.13"
+version = "0.6.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "93ab48844d61251bb3835145c521d88aa4031d7139e8485990f60ca911fa0815"
+checksum = "0445d0fbc924bb93539b4316c11afb121ea39296f99a3c4c9edad09e3658cdef"
 dependencies = [
+ "aes",
  "byteorder",
  "bzip2",
+ "constant_time_eq",
  "crc32fast",
+ "crossbeam-utils",
  "flate2",
- "thiserror",
+ "hmac",
+ "pbkdf2",
+ "sha1",
  "time",
+ "zstd",
 ]
 
 [[package]]
-name = "zip-extensions"
-version = "0.6.1"
+name = "zstd"
+version = "0.11.2+zstd.1.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a64c3c977bc3434ce2d4bcea8ad3c644672de0f2c402b72b9171ca80a8885d14"
+checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
 dependencies = [
- "zip",
+ "zstd-safe",
+]
+
+[[package]]
+name = "zstd-safe"
+version = "5.0.2+zstd.1.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db"
+dependencies = [
+ "libc",
+ "zstd-sys",
+]
+
+[[package]]
+name = "zstd-sys"
+version = "2.0.8+zstd.1.5.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5556e6ee25d32df2586c098bbfa278803692a20d0ab9565e049480d52707ec8c"
+dependencies = [
+ "cc",
+ "libc",
+ "pkg-config",
+]

benchmark/Cargo.toml
@@ -27,7 +27,7 @@ serde = {version = "1.0.142", features = ["derive"]}
 serde_json = "1.0"
 text-generation-client = { path = "../router/client" }
 thiserror = "1.0.38"
-tokenizers = "0.13.2"
+tokenizers = "0.13.3"
 tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
 tui = {package = "ratatui", version = "0.20", default-features = false, features = ["crossterm"]}
 tracing = "0.1.37"

benchmark/src/generation.rs
@@ -75,7 +75,7 @@ async fn generate_runs(
     // Warmups on batch size
     for _ in 0..warmups {
         let (_, decode_batch) =
-            prefill(sequence.clone(), b, decode_length, &mut client).await?;
+            prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
         let _ = decode(decode_batch, &mut client).await?;
         // Send warmup message
         run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
@@ -83,7 +83,7 @@ async fn generate_runs(
 
     for _ in 0..n_runs {
         let (prefill, decode_batch) =
-            prefill(sequence.clone(), b, decode_length, &mut client).await?;
+            prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?;
         // Send prefill message
         run_sender
             .send(Ok(Message::Prefill(prefill)))
@@ -110,6 +110,7 @@ async fn generate_runs(
 // Run a prefill step
 async fn prefill(
     sequence: String,
+    sequence_length: u32,
     batch_size: u32,
     decode_length: u32,
     client: &mut ShardedClient,
@@ -119,6 +120,7 @@ async fn prefill(
         .map(|id| Request {
             id: id.into(),
             inputs: sequence.clone(),
+            truncate: sequence_length,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
                 top_k: 0,

proto/generate.proto
@@ -63,10 +63,12 @@ message Request {
     uint64 id = 1;
     /// The generation context
     string inputs = 2;
+    /// Context truncation
+    uint32 truncate = 3;
     /// Next Token Chooser Parameters
-    NextTokenChooserParameters parameters = 3;
+    NextTokenChooserParameters parameters = 4;
     /// Stopping Criteria Parameters
-    StoppingCriteriaParameters stopping_parameters = 4;
+    StoppingCriteriaParameters stopping_parameters = 5;
 }
 
 message Batch {
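
Note: reassembled from the hunk above, the Request message now reads as below. Renumbering parameters and stopping_parameters is safe here only because the router and the Python server ship and upgrade together; on a public protobuf API, shifting field tags like this would break existing clients on the wire.

    message Request {
        uint64 id = 1;
        /// The generation context
        string inputs = 2;
        /// Context truncation
        uint32 truncate = 3;
        /// Next Token Chooser Parameters
        NextTokenChooserParameters parameters = 4;
        /// Stopping Criteria Parameters
        StoppingCriteriaParameters stopping_parameters = 5;
    }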

router/Cargo.toml
@@ -18,21 +18,20 @@ axum = { version = "0.6.4", features = ["json"] }
 axum-tracing-opentelemetry = "0.9.0"
 text-generation-client = { path = "client" }
 clap = { version = "4.1.4", features = ["derive", "env"] }
+flume = "0.10.14"
 futures = "0.3.26"
 metrics = "0.20.1"
 metrics-exporter-prometheus = { version = "0.11.0", features = [] }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
-parking_lot = "0.12.1"
 rand = "0.8.5"
 reqwest = { version = "0.11.14", features = [] }
 serde = "1.0.152"
 serde_json = "1.0.93"
 thiserror = "1.0.38"
-tokenizers = "0.13.2"
+tokenizers = "0.13.3"
 tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
-tokio-stream = "0.1.11"
 tower-http = { version = "0.3.5", features = ["cors"] }
 tracing = "0.1.37"
 tracing-opentelemetry = "0.18.0"

router/src/infer.rs
@@ -2,17 +2,17 @@
 use crate::validation::{Validation, ValidationError};
 use crate::{Entry, Queue, Token};
 use crate::{GenerateRequest, PrefillToken};
+use flume::r#async::RecvStream;
 use futures::future::try_join_all;
+use futures::stream::StreamExt;
 use nohash_hasher::IntMap;
 use std::sync::Arc;
 use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,
 };
 use thiserror::Error;
-use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
+use tokio::sync::{Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
-use tokio_stream::wrappers::UnboundedReceiverStream;
-use tokio_stream::StreamExt;
 use tracing::{info_span, instrument, Instrument, Span};
 
 /// Inference struct
@@ -73,7 +73,7 @@ impl Infer {
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
-    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+    ) -> Result<RecvStream<Result<InferStreamResponse, InferError>>, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
         // This permit will live as long as Entry
         let permit = self
@@ -87,10 +87,14 @@ impl Infer {
             })?;
 
         // Validate request
-        let valid_request = self.validation.validate(request).await?;
+        let valid_request = self.validation.validate(request).await.map_err(|err| {
+            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+            tracing::error!("{err}");
+            err
+        })?;
 
         // MPSC channel to communicate with the background batching task
-        let (response_tx, response_rx) = mpsc::unbounded_channel();
+        let (response_tx, response_rx) = flume::unbounded();
 
         // Append the request to the queue
         self.queue.append(Entry {
@@ -108,7 +112,7 @@ impl Infer {
         self.shared.batching_task.notify_one();
 
         // Return stream
-        Ok(UnboundedReceiverStream::new(response_rx))
+        Ok(response_rx.into_stream())
     }
 
     /// Add a new request to the queue and return a InferResponse
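
Note: the switch from tokio::sync::mpsc to flume works because a flume Receiver can be consumed both as an async Stream (as generate_stream does above) and from plain blocking threads. A minimal self-contained sketch of the pattern, assuming flume 0.10, tokio, and futures; the u32 payload stands in for the router's response type:

    use futures::stream::StreamExt;

    #[tokio::main]
    async fn main() {
        let (tx, rx) = flume::unbounded::<u32>();

        // Producer side (stand-in for the batching task).
        tokio::spawn(async move {
            for i in 0..3 {
                // `send` never blocks on an unbounded channel.
                tx.send(i).unwrap();
            }
        });

        // Consumer side: turn the receiver into an async Stream,
        // exactly as the handler returns response_rx.into_stream().
        let mut stream = rx.into_stream();
        while let Some(i) = stream.next().await {
            println!("got {i}");
        }
    }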

router/src/main.rs
@@ -37,7 +37,7 @@ struct Args {
     max_waiting_tokens: usize,
     #[clap(default_value = "3000", long, short, env)]
     port: u16,
-    #[clap(default_value = "/tmp/text-generation-0", long, env)]
+    #[clap(default_value = "/tmp/text-generation-server-0", long, env)]
     master_shard_uds_path: String,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
@@ -94,11 +94,11 @@ fn main() -> Result<(), std::io::Error> {
     if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
     {
         // Load local tokenizer
-        Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
+        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
     } else {
         // Download and instantiate tokenizer
         // We need to download it outside of the Tokio runtime
-        Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
+        Tokenizer::from_pretrained(tokenizer_name.clone(), None).ok()
     };
 
     // Launch Tokio runtime
@@ -109,6 +109,13 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
 
+            if tokenizer.is_none() {
+                tracing::warn!(
+                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
+                );
+                tracing::warn!("Rust input length validation and truncation is disabled");
+            }
+
             // Get pipeline tag
             let model_info = reqwest::get(format!(
                 "https://huggingface.co/api/models/{tokenizer_name}"
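
Note: the whole feature hinges on the two .unwrap() -> .ok() changes above: a tokenizer that fails to load (for example, a model that ships only a slow tokenizer without tokenizer.json) now yields None and a warning instead of a panic. A condensed sketch of that fallback, assuming the tokenizers crate; the function name and signature are illustrative, not the router's:

    use std::path::Path;
    use tokenizers::Tokenizer;

    fn load_tokenizer(local: &Path, name: &str) -> Option<Tokenizer> {
        if local.join("tokenizer.json").exists() {
            // A failed local load becomes None rather than a panic.
            Tokenizer::from_file(local.join("tokenizer.json")).ok()
        } else {
            // Same for a failed download from the Hub.
            Tokenizer::from_pretrained(name, None).ok()
        }
    }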

router/src/queue.rs
@@ -4,8 +4,7 @@ use crate::validation::ValidGenerateRequest;
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::min;
 use text_generation_client::{Batch, Request};
-use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
-use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
+use tokio::sync::{oneshot, OwnedSemaphorePermit};
 use tokio::time::Instant;
 use tracing::{info_span, instrument, Span};
 
@@ -15,7 +14,7 @@ pub(crate) struct Entry {
     /// Request
     pub request: ValidGenerateRequest,
     /// Response sender to communicate between the Infer struct and the batching_task
-    pub response_tx: UnboundedSender<Result<InferStreamResponse, InferError>>,
+    pub response_tx: flume::Sender<Result<InferStreamResponse, InferError>>,
     /// Span that will live as long as entry
     pub span: Span,
     /// Temporary span used as a guard when logging inference, wait times...
@@ -32,13 +31,13 @@ pub(crate) struct Entry {
 #[derive(Debug, Clone)]
 pub(crate) struct Queue {
     /// Channel to communicate with the background queue task
-    queue_sender: UnboundedSender<QueueCommand>,
+    queue_sender: flume::Sender<QueueCommand>,
 }
 
 impl Queue {
     pub(crate) fn new() -> Self {
         // Create channel
-        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
+        let (queue_sender, queue_receiver) = flume::unbounded();
 
         // Launch background queue task
         tokio::spawn(queue_task(queue_receiver));
@@ -82,10 +81,10 @@ impl Queue {
 }
 
 // Background task responsible of the queue state
-async fn queue_task(mut receiver: UnboundedReceiver<QueueCommand>) {
+async fn queue_task(receiver: flume::Receiver<QueueCommand>) {
     let mut state = State::new();
 
-    while let Some(cmd) = receiver.recv().await {
+    while let Ok(cmd) = receiver.recv_async().await {
         match cmd {
             QueueCommand::Append(entry, span) => span.in_scope(|| state.append(entry)),
             QueueCommand::NextBatch {
@@ -174,6 +173,7 @@ impl State {
             batch_requests.push(Request {
                 id,
                 inputs: entry.request.inputs.clone(),
+                truncate: entry.request.truncate,
                 parameters: Some(entry.request.parameters.clone()),
                 stopping_parameters: Some(entry.request.stopping_parameters.clone()),
             });
@@ -215,17 +215,18 @@ mod tests {
     use super::*;
     use std::sync::Arc;
     use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
-    use tokio::sync::{mpsc, Semaphore};
+    use tokio::sync::Semaphore;
     use tracing::info_span;
 
     fn default_entry() -> Entry {
         let semaphore = Arc::new(Semaphore::new(1));
-        let (response_tx, _) = mpsc::unbounded_channel();
+        let (response_tx, _) = flume::unbounded();
         let permit = semaphore.try_acquire_owned().unwrap();
 
         Entry {
             request: ValidGenerateRequest {
                 inputs: "".to_string(),
+                truncate: 0,
                 parameters: NextTokenChooserParameters {
                     temperature: 0.0,
                     top_k: 0,

router/src/server.rs
@@ -13,6 +13,7 @@ use axum::response::{IntoResponse, Response};
 use axum::routing::{get, post};
 use axum::{http, Json, Router};
 use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
+use futures::stream::StreamExt;
 use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
@@ -21,7 +22,6 @@ use text_generation_client::ShardedClient;
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::time::Instant;
-use tokio_stream::StreamExt;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
@@ -87,21 +87,21 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
 
 /// Generate tokens
 #[utoipa::path(
 post,
 tag = "Text Generation Inference",
 path = "/generate",
 request_body = GenerateRequest,
 responses(
 (status = 200, description = "Generated Text", body = GenerateResponse),
 (status = 424, description = "Generation Error", body = ErrorResponse,
 example = json ! ({"error": "Request failed during generation"})),
 (status = 429, description = "Model is overloaded", body = ErrorResponse,
 example = json ! ({"error": "Model is overloaded"})),
 (status = 422, description = "Input validation error", body = ErrorResponse,
 example = json ! ({"error": "Input validation error"})),
 (status = 500, description = "Incomplete generation", body = ErrorResponse,
 example = json ! ({"error": "Incomplete generation"})),
 )
 )]
 #[instrument(
 skip(infer),
@@ -264,26 +264,26 @@ async fn generate(
 
 /// Generate a stream of token using Server-Sent Events
 #[utoipa::path(
 post,
 tag = "Text Generation Inference",
 path = "/generate_stream",
 request_body = GenerateRequest,
 responses(
 (status = 200, description = "Generated Text", body = StreamResponse,
 content_type = "text/event-stream"),
 (status = 424, description = "Generation Error", body = ErrorResponse,
 example = json ! ({"error": "Request failed during generation"}),
 content_type = "text/event-stream"),
 (status = 429, description = "Model is overloaded", body = ErrorResponse,
 example = json ! ({"error": "Model is overloaded"}),
 content_type = "text/event-stream"),
 (status = 422, description = "Input validation error", body = ErrorResponse,
 example = json ! ({"error": "Input validation error"}),
 content_type = "text/event-stream"),
 (status = 500, description = "Incomplete generation", body = ErrorResponse,
 example = json ! ({"error": "Incomplete generation"}),
 content_type = "text/event-stream"),
 )
 )]
 #[instrument(
 skip(infer),
@@ -447,10 +447,10 @@ async fn generate_stream(
 
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
 get,
 tag = "Text Generation Inference",
 path = "/metrics",
 responses((status = 200, description = "Prometheus Metrics", body = String))
 )]
 async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
@@ -468,7 +468,7 @@ pub async fn run(
     max_batch_size: usize,
     max_waiting_tokens: usize,
     client: ShardedClient,
-    tokenizer: Tokenizer,
+    tokenizer: Option<Tokenizer>,
     validation_workers: usize,
     addr: SocketAddr,
     allow_origin: Option<AllowOrigin>,
@@ -476,36 +476,36 @@ pub async fn run(
 
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
     paths(
     generate,
     generate_stream,
     metrics,
     ),
     components(
     schemas(
     GenerateRequest,
     GenerateParameters,
     PrefillToken,
     Token,
     GenerateResponse,
     BestOfSequence,
     Details,
     FinishReason,
     StreamResponse,
     StreamDetails,
     ErrorResponse,
     )
     ),
     tags(
     (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
     ),
     info(
     title = "Text Generation Inference",
     license(
     name = "Apache 2.0",
     url = "https://www.apache.org/licenses/LICENSE-2.0"
     )
     )
     )]
     struct ApiDoc;
 

router/src/validation.rs
@@ -1,50 +1,129 @@
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 /// Payload validation logic
 use crate::{GenerateParameters, GenerateRequest};
-use rand::rngs::ThreadRng;
-use rand::Rng;
+use rand::{thread_rng, Rng};
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::oneshot;
 use tracing::{instrument, Span};
 
 /// Validation
 #[derive(Debug, Clone)]
 pub struct Validation {
-    /// maximum value for the best_of parameter
-    #[allow(dead_code)]
+    /// Validation parameters
     max_best_of: usize,
-    /// Channel to communicate with the background validation task
-    sender: mpsc::UnboundedSender<ValidationRequest>,
+    max_stop_sequences: usize,
+    max_input_length: usize,
+    max_total_tokens: usize,
+    /// Channel to communicate with the background tokenization task
+    sender: Option<flume::Sender<TokenizerRequest>>,
 }
 
 impl Validation {
     pub(crate) fn new(
         workers: usize,
-        tokenizer: Tokenizer,
+        tokenizer: Option<Tokenizer>,
         max_best_of: usize,
         max_stop_sequences: usize,
         max_input_length: usize,
         max_total_tokens: usize,
     ) -> Self {
-        // Create channel
-        let (validation_sender, validation_receiver) = mpsc::unbounded_channel();
+        if max_input_length >= max_total_tokens {
+            panic!("`max_input_length` must be < `max_total_tokens`");
+        }
 
-        // Launch background validation task
-        tokio::spawn(validation_task(
-            workers,
-            tokenizer,
-            max_stop_sequences,
-            max_input_length,
-            max_total_tokens,
-            validation_receiver,
-        ));
+        // If we have a fast tokenizer
+        let sender = if let Some(tokenizer) = tokenizer {
+            // Create channel
+            let (validation_sender, validation_receiver) = flume::unbounded();
+
+            // Create workers
+            for _ in 0..workers {
+                let tokenizer_clone = tokenizer.clone();
+                let receiver_clone = validation_receiver.clone();
+
+                // Spawn worker
+                tokio::task::spawn_blocking(move || {
+                    tokenizer_worker(tokenizer_clone, receiver_clone)
+                });
+            }
+            Some(validation_sender)
+        } else {
+            None
+        };
 
         Self {
             max_best_of,
-            sender: validation_sender,
+            sender,
+            max_stop_sequences,
+            max_input_length,
+            max_total_tokens,
+        }
+    }
+
+    #[instrument(skip_all)]
+    async fn validate_input(
+        &self,
+        inputs: String,
+        truncate: Option<usize>,
+        max_new_tokens: u32,
+    ) -> Result<String, ValidationError> {
+        // If we have a fast tokenizer
+        if let Some(sender) = &self.sender {
+            // Create response channel
+            let (response_sender, response_receiver) = oneshot::channel();
+            // Send request to the background validation task
+            // Unwrap is safe here
+            sender
+                .send(((inputs, truncate), response_sender, Span::current()))
+                .unwrap();
+
+            // Await on response channel
+            // Unwrap is safe here
+            let (inputs, input_length) = response_receiver.await.unwrap()?;
+
+            // Get total tokens
+            let total_tokens = input_length + max_new_tokens as usize;
+
+            // Validate MaxTotalTokens
+            if total_tokens > self.max_total_tokens {
+                return Err(ValidationError::MaxTotalTokens(
+                    self.max_total_tokens,
+                    input_length,
+                    max_new_tokens,
+                ));
+            }
+
+            // Validate InputLength
+            if input_length > self.max_input_length {
+                return Err(ValidationError::InputLength(
+                    self.max_input_length,
+                    input_length,
+                ));
+            }
+
+            metrics::histogram!("tgi_request_input_length", input_length as f64);
+            Ok(inputs)
+        }
+        // Return inputs without validation
+        else {
+            // In this case, we don't know the real length in tokens of the inputs
+            // However, the inputs will be truncated by the python servers
+            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
+
+            // Validate MaxNewTokens
+            if (truncate.unwrap_or(self.max_input_length) as u32 + max_new_tokens)
+                > self.max_total_tokens as u32
+            {
+                return Err(ValidationError::MaxNewTokens(
+                    self.max_total_tokens - self.max_input_length,
+                    max_new_tokens,
+                ));
+            }
+
+            Ok(inputs)
         }
     }
 
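
Note: tokenization is CPU-bound, so the workers created above run under tokio::task::spawn_blocking and all share one flume::Receiver, which is cloneable because flume channels are multi-producer multi-consumer; idle workers simply block on recv, giving free load balancing. A self-contained sketch of that layout, with whitespace splitting standing in for the real tokenizer and simplified request types:

    use tokio::sync::oneshot;

    // (input text, channel to send the token count back on)
    type Job = (String, oneshot::Sender<usize>);

    fn worker(receiver: flume::Receiver<Job>) {
        // Blocking recv is fine here: this runs on a dedicated
        // blocking thread, not on the async runtime.
        while let Ok((text, respond)) = receiver.recv() {
            let token_count = text.split_whitespace().count(); // stand-in for tokenization
            let _ = respond.send(token_count);
        }
    }

    #[tokio::main]
    async fn main() {
        let (tx, rx) = flume::unbounded::<Job>();
        for _ in 0..2 {
            let rx = rx.clone(); // flume receivers are cloneable (mpmc)
            tokio::task::spawn_blocking(move || worker(rx));
        }

        let (respond, response) = oneshot::channel();
        tx.send(("hello world".to_string(), respond)).unwrap();
        assert_eq!(response.await.unwrap(), 2);
    }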
@@ -54,16 +133,139 @@ impl Validation {
         &self,
         request: GenerateRequest,
     ) -> Result<ValidGenerateRequest, ValidationError> {
-        // Create response channel
-        let (sender, receiver) = oneshot::channel();
-        // Send request to the background validation task
-        // Unwrap is safe here
-        self.sender
-            .send((request, sender, Span::current()))
-            .unwrap();
-        // Await on response channel
-        // Unwrap is safe here
-        receiver.await.unwrap()
+        let GenerateParameters {
+            best_of,
+            temperature,
+            repetition_penalty,
+            top_k,
+            top_p,
+            typical_p,
+            do_sample,
+            max_new_tokens,
+            stop: stop_sequences,
+            truncate,
+            seed,
+            watermark,
+            ..
+        } = request.parameters;
+
+        // sampling must be true when best_of > 1
+        let best_of = best_of.unwrap_or(1);
+        let sampling = do_sample
+            || temperature.is_some()
+            || top_k.is_some()
+            || top_p.is_some()
+            || typical_p.is_some();
+
+        if best_of > 1 && !sampling {
+            return Err(BestOfSampling);
+        }
+
+        let temperature = temperature.unwrap_or(1.0);
+        if temperature <= 0.0 {
+            return Err(ValidationError::Temperature);
+        }
+
+        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
+        if repetition_penalty <= 0.0 {
+            return Err(ValidationError::RepetitionPenalty);
+        }
+
+        // Different because the proto default value is not a valid value
+        // for the user
+        let top_p = top_p
+            .map(|value| {
+                if value <= 0.0 || value >= 1.0 {
+                    return Err(ValidationError::TopP);
+                }
+                Ok(value)
+            })
+            .unwrap_or(Ok(1.0))?;
+
+        let typical_p = typical_p
+            .map(|value| {
+                if value <= 0.0 || value >= 1.0 {
+                    return Err(ValidationError::TypicalP);
+                }
+                Ok(value)
+            })
+            .unwrap_or(Ok(1.0))?;
+
+        let top_k: u32 = top_k
+            .map(|value| {
+                if value <= 0 {
+                    return Err(ValidationError::TopK);
+                }
+                Ok(value as u32)
+            })
+            .unwrap_or(Ok(0))?;
+
+        if max_new_tokens == 0 {
+            return Err(ValidationError::NegativeMaxNewTokens);
+        }
+
+        if stop_sequences.len() > self.max_stop_sequences {
+            return Err(ValidationError::StopSequence(
+                self.max_stop_sequences,
+                stop_sequences.len(),
+            ));
+        }
+
+        // If seed is None, assign a random one
+        let seed = match seed {
+            None => thread_rng().gen(),
+            Some(seed) => {
+                if best_of > 1 {
+                    return Err(BestOfSeed);
+                }
+                seed
+            }
+        };
+
+        // Check if inputs is empty
+        if request.inputs.is_empty() {
+            return Err(EmptyInput);
+        }
+
+        // Check if truncate is strictly positive and less than max_input_length
+        let truncate = truncate
+            .map(|value| {
+                if value == 0 || value > self.max_input_length {
+                    return Err(ValidationError::Truncate(self.max_input_length, value));
+                }
+                Ok(Some(value))
+            })
+            .unwrap_or(Ok(None))?;
+
+        // Validate inputs
+        let inputs = self
+            .validate_input(request.inputs, truncate, max_new_tokens)
+            .await?;
+
+        let parameters = NextTokenChooserParameters {
+            temperature,
+            repetition_penalty,
+            top_k,
+            top_p,
+            typical_p,
+            do_sample,
+            seed,
+            watermark,
+        };
+        let stopping_parameters = StoppingCriteriaParameters {
+            max_new_tokens,
+            stop_sequences,
+            ignore_eos_token: false,
+        };
+
+        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
+
+        Ok(ValidGenerateRequest {
+            inputs,
+            truncate: truncate.unwrap_or(self.max_input_length) as u32,
+            parameters,
+            stopping_parameters,
+        })
     }
 
     /// Validate the best_of parameter
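Note: `top_p`, `typical_p` and `top_k` are handled differently from the other parameters because the proto default value (0) is not a value a user may legally send, so "unset" has to be mapped to a neutral default. A small Python sketch of the same rule (illustrative names, not from the codebase):

    def validate_top_p(top_p):
        # None means "not set": map it to the neutral value 1.0;
        # anything outside the open interval (0, 1) is rejected.
        if top_p is None:
            return 1.0
        if top_p <= 0.0 or top_p >= 1.0:
            raise ValueError("`top_p` must be > 0.0 and < 1.0")
        return top_p

    assert validate_top_p(None) == 1.0
    assert validate_top_p(0.9) == 0.9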
@@ -81,262 +283,57 @@ impl Validation {
     }
 }
 
-/// Validation task
-/// Load balance the validation requests between multiple validation workers
-async fn validation_task(
-    workers: usize,
-    tokenizer: Tokenizer,
-    max_stop_sequences: usize,
-    max_input_length: usize,
-    max_total_tokens: usize,
-    mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
-) {
-    let mut workers_senders = Vec::with_capacity(workers);
-
-    // Create workers
-    for _ in 0..workers {
-        let tokenizer_clone: Tokenizer = tokenizer.clone().into();
-        // Create channel to communicate with worker
-        let (worker_sender, worker_receiver) = mpsc::channel(workers);
-        workers_senders.push(worker_sender);
-
-        // Spawn worker
-        tokio::task::spawn_blocking(move || {
-            validation_worker(
-                tokenizer_clone,
-                max_stop_sequences,
-                max_input_length,
-                max_total_tokens,
-                worker_receiver,
-            )
-        });
-    }
-
-    loop {
-        // Load balance requests between workers
-        for sender in workers_senders.iter() {
-            if let Some(validation_request) = receiver.recv().await {
-                sender.send(validation_request).await.unwrap();
-            } else {
-                return;
-            }
-        }
-    }
-}
-
-/// Check the parameters inside the payload and get the number of tokens inside the input using
-/// the tokenizer
-fn validation_worker(
-    tokenizer: Tokenizer,
-    max_stop_sequences: usize,
-    max_input_length: usize,
-    max_total_tokens: usize,
-    mut receiver: mpsc::Receiver<ValidationRequest>,
-) {
-    // Seed rng
-    let mut rng = rand::thread_rng();
-
+/// Start tokenization workers
+fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) {
     // Loop over requests
-    while let Some((request, response_tx, parent_span)) = receiver.blocking_recv() {
+    while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() {
         parent_span.in_scope(|| {
             response_tx
-                .send(
-                    validate(
-                        request,
-                        &tokenizer,
-                        max_stop_sequences,
-                        max_input_length,
-                        max_total_tokens,
-                        &mut rng,
-                    )
-                    .map_err(|err| {
-                        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
-                        tracing::error!("{err}");
-                        err
-                    }),
-                )
+                .send(prepare_input(inputs, truncate, &tokenizer))
                 .unwrap_or(())
         })
     }
 }
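Note: each tokenization request carries its own one-shot response channel, so many router tasks can share one pool of blocking workers without a central dispatcher. A rough Python analogue of the flume/oneshot round trip (the stand-in "tokenizer" and all names are ours):

    import queue, threading

    def tokenizer_worker(requests):
        # Blocking loop, like tokenizer_worker above
        while True:
            item = requests.get()
            if item is None:  # shutdown signal
                return
            (inputs, truncate), response = item
            response.put((inputs, len(inputs.split())))  # stand-in "tokenizer"

    requests = queue.Queue()
    threading.Thread(target=tokenizer_worker, args=(requests,), daemon=True).start()

    response = queue.Queue(maxsize=1)  # one-shot channel
    requests.put((("Hello world", None), response))
    print(response.get())  # ('Hello world', 2)
    requests.put(None)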
 
-fn validate(
-    request: GenerateRequest,
+/// Get input length and optionally truncate it
+fn prepare_input(
+    inputs: String,
+    truncate: Option<usize>,
     tokenizer: &Tokenizer,
-    max_stop_sequences: usize,
-    max_input_length: usize,
-    max_total_tokens: usize,
-    rng: &mut ThreadRng,
-) -> Result<ValidGenerateRequest, ValidationError> {
-    let GenerateParameters {
-        best_of,
-        temperature,
-        repetition_penalty,
-        top_k,
-        top_p,
-        typical_p,
-        do_sample,
-        max_new_tokens,
-        stop: stop_sequences,
-        truncate,
-        seed,
-        watermark,
-        ..
-    } = request.parameters;
-
-    // sampling must be true when best_of > 1
-    let best_of = best_of.unwrap_or(1);
-    let sampling = do_sample
-        || temperature.is_some()
-        || top_k.is_some()
-        || top_p.is_some()
-        || typical_p.is_some();
-
-    if best_of > 1 && !sampling {
-        return Err(BestOfSampling);
-    }
-
-    let temperature = temperature.unwrap_or(1.0);
-    if temperature <= 0.0 {
-        return Err(ValidationError::Temperature);
-    }
-
-    let repetition_penalty = repetition_penalty.unwrap_or(1.0);
-    if repetition_penalty <= 0.0 {
-        return Err(ValidationError::RepetitionPenalty);
-    }
-
-    // Different because the proto default value is not a valid value
-    // for the user
-    let top_p = top_p
-        .map(|value| {
-            if value <= 0.0 || value >= 1.0 {
-                return Err(ValidationError::TopP);
-            }
-            Ok(value)
-        })
-        .unwrap_or(Ok(1.0))?;
-
-    let typical_p = typical_p
-        .map(|value| {
-            if value <= 0.0 || value >= 1.0 {
-                return Err(ValidationError::TypicalP);
-            }
-            Ok(value)
-        })
-        .unwrap_or(Ok(1.0))?;
-
-    let top_k: u32 = top_k
-        .map(|value| {
-            if value <= 0 {
-                return Err(ValidationError::TopK);
-            }
-            Ok(value as u32)
-        })
-        .unwrap_or(Ok(0))?;
-
-    if max_new_tokens == 0 {
-        return Err(ValidationError::MaxNewTokens);
-    }
-
-    if stop_sequences.len() > max_stop_sequences {
-        return Err(ValidationError::StopSequence(
-            max_stop_sequences,
-            stop_sequences.len(),
-        ));
-    }
-
-    // If seed is None, assign a random one
-    let seed = match seed {
-        None => rng.gen(),
-        Some(seed) => {
-            if best_of > 1 {
-                return Err(BestOfSeed);
-            }
-            seed
-        }
-    };
-
-    // Check if inputs is empty
-    if request.inputs.is_empty() {
-        return Err(EmptyInput);
-    }
-
-    // Check if truncate is strictly positive and less than max_input_length
-    let truncate = truncate
-        .map(|value| {
-            if value == 0 || value > max_input_length {
-                return Err(ValidationError::Truncate(max_input_length, value));
-            }
-            Ok(Some(value))
-        })
-        .unwrap_or(Ok(None))?;
-
+) -> Result<(String, usize), ValidationError> {
     // Get the number of tokens in the input
     let mut encoding = tokenizer
-        .encode(request.inputs.clone(), true)
+        .encode(inputs.clone(), true)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
-    let (inputs, input_length) = if let Some(truncate) = truncate {
-        // truncate encoding and decode new inputs
-        encoding.truncate(truncate, 0, TruncationDirection::Left);
-        let inputs = tokenizer
-            .decode(Vec::from(encoding.get_ids()), false)
-            .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
-        (inputs, encoding.len())
-    } else {
-        (request.inputs, encoding.len())
+    // Optionally truncate
+    let (inputs, input_length) = match truncate {
+        // Truncate is some and < encoding length
+        Some(truncate) if truncate < encoding.len() => {
+            // truncate encoding and decode new inputs
+            encoding.truncate(truncate, 0, TruncationDirection::Left);
+            let inputs = tokenizer
+                .decode(Vec::from(encoding.get_ids()), false)
+                .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
+            (inputs, encoding.len())
+        }
+        // Nothing to do
+        _ => (inputs, encoding.len()),
     };
 
-    if input_length > max_input_length {
-        return Err(ValidationError::InputLength(max_input_length, input_length));
-    }
-
-    let total_tokens = input_length + max_new_tokens as usize;
-    if total_tokens > max_total_tokens {
-        return Err(ValidationError::MaxTotalTokens(
-            max_total_tokens,
-            input_length,
-            max_new_tokens,
-        ));
-    }
-
-    // Return ValidGenerateRequest
-    let parameters = NextTokenChooserParameters {
-        temperature,
-        repetition_penalty,
-        top_k,
-        top_p,
-        typical_p,
-        do_sample,
-        seed,
-        watermark,
-    };
-    let stopping_parameters = StoppingCriteriaParameters {
-        max_new_tokens,
-        stop_sequences,
-        ignore_eos_token: false,
-    };
-
-    metrics::histogram!("tgi_request_input_length", input_length as f64);
-    metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
-
-    Ok(ValidGenerateRequest {
-        inputs,
-        parameters,
-        stopping_parameters,
-    })
+    Ok((inputs, input_length))
 }
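Note: `prepare_input` truncates from the left (`TruncationDirection::Left`), keeping the most recent tokens of the prompt, and re-decodes so the Python server receives the already-shortened text; truncation only fires when the prompt is actually longer than the limit. A toy Python sketch of the same keep-the-tail logic (whitespace "tokens" and an illustrative decode callback, not the real tokenizer):

    def prepare_input(token_ids, text, truncate, decode):
        # Only truncate when the prompt exceeds the limit
        if truncate is not None and truncate < len(token_ids):
            token_ids = token_ids[-truncate:]  # keep the *last* tokens
            text = decode(token_ids)
        return text, len(token_ids)

    text, n = prepare_input([7, 8, 9, 10], "a b c d", 2, lambda ids: f"<{len(ids)} ids>")
    assert n == 2  # only the last two tokens survive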
 
-type ValidationRequest = (
-    GenerateRequest,
-    oneshot::Sender<Result<ValidGenerateRequest, ValidationError>>,
+type TokenizerRequest = (
+    (String, Option<usize>),
+    oneshot::Sender<Result<(String, usize), ValidationError>>,
     Span,
 );
 
 #[derive(Debug)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
+    pub truncate: u32,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }
 
@@ -366,7 +363,9 @@ pub enum ValidationError {
     #[error("`typical_p` must be > 0.0 and < 1.0")]
     TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
-    MaxNewTokens,
+    NegativeMaxNewTokens,
+    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
+    MaxNewTokens(usize, u32),
     #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
     MaxTotalTokens(usize, usize, u32),
     #[error("`inputs` must have less than {0} tokens. Given: {1}")]
 
@@ -24,6 +24,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
 
@@ -25,6 +25,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
 
@@ -15,6 +15,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
+        truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
@@ -30,6 +31,7 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
+        truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
 
@@ -28,6 +28,7 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
+        truncate=100,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
 
@@ -68,7 +68,7 @@ class BLOOMSharded(BLOOM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
         config = AutoConfig.from_pretrained(
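Note: every model now loads its tokenizer with truncation_side="left", so server-side truncation agrees with the router's left truncation and keeps the end of the prompt. A quick check (assumes the `transformers` package, that "gpt2" is downloadable, and that each word here maps to one token):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")
    ids = tok("one two three four five", truncation=True, max_length=2)["input_ids"]
    print(tok.decode(ids))  # " four five": the head of the prompt was dropped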
@@ -66,6 +66,7 @@ class CausalLMBatch(Batch):
         stopping_criterias = []
 
         # Parse batch
+        max_truncation = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -74,6 +75,7 @@ class CausalLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            max_truncation = max(max_truncation, r.truncate)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -83,6 +85,8 @@ class CausalLMBatch(Batch):
             return_tensors="pt",
             padding=True,
             return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
         ).to(device)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
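Note: the batch is tokenized in one padded call, which only accepts a single max_length, so the code passes the per-request maximum. When the router has a tokenizer it already cut each input to its own `truncate`, making this a backstop; without one, this is where truncation actually happens. Sketch (illustrative only):

    def batch_truncation_limit(requests):
        # One tokenizer call per batch => one max_length; take the largest
        return max(r["truncate"] for r in requests)

    assert batch_truncation_limit([{"truncate": 100}, {"truncate": 30}]) == 100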
@@ -38,7 +38,7 @@ from flash_attn.layers.rotary import RotaryEmbedding
 
 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 6144:
+        if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
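Note: the fused dropout_layer_norm kernel only covers hidden sizes up to a limit, so models above the threshold (raised here from 6144 to 8192, and again in the second FastLayerNorm hunk further below) take the eager residual-add + nn.LayerNorm path. A runnable sketch of just that fallback branch (the class name is ours):

    import torch
    import torch.nn as nn

    class FallbackLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] > 8192:
                # Eager path: fold the residual in, then plain LayerNorm
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states
                return super().forward(hidden_states), residual
            raise NotImplementedError("fused kernel path elided in this sketch")

    ln = FallbackLayerNorm(16384)
    out, res = ln(torch.randn(2, 16384), torch.randn(2, 16384))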
@@ -624,13 +624,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 
 
 class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config, process_group=None):
         super().__init__(config)
 
-        if config.tp_parallel:
-            process_group = torch.distributed.distributed_c10d._get_default_group()
+        self.process_group = process_group
+        if self.process_group is not None:
+            self.world_size = self.process_group.size()
+            self.rank = self.process_group.rank()
         else:
-            process_group = None
+            self.world_size = 1
+            self.rank = 0
 
         self.gpt_neox = FlashGPTNeoXModel(config, process_group)
 
@@ -668,4 +671,13 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         hidden_states, present = self.gpt_neox(
             input_ids, position_ids, cu_seqlens, max_s, past_key_values
         )
-        return self.embed_out(hidden_states), present
+        logits = self.embed_out(hidden_states)
+
+        if self.gpt_neox.tp_embeddings:
+            # Logits are sharded, so we need to gather them
+            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
+            world_logits = torch.cat(world_logits, dim=1)
+
+            return world_logits, present
+        return logits, present
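Note: with tensor-parallel embeddings each rank only holds a vocabulary shard of the logits, so the model now gathers every shard and concatenates them along the vocab dimension before returning. A sketch of that collective (assumes torch with the gloo backend; the single-process group exists only to make the sketch executable):

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    def gather_sharded_logits(logits, group=None):
        # One shard per rank, concatenated on the vocab axis (dim=1)
        world_size = dist.get_world_size(group)
        world_logits = [torch.empty_like(logits) for _ in range(world_size)]
        dist.all_gather(world_logits, logits, group=group)
        return torch.cat(world_logits, dim=1)

    full = gather_sharded_logits(torch.randn(4, 8))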
@@ -11,7 +11,7 @@ import dropout_layer_norm
 
 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 6144:
+        if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
@@ -78,7 +78,9 @@ class FlashCausalLMBatch(Batch):
 
         # Parse batch
         for r in pb.requests:
-            tokenized_input = tokenizer(r.inputs)["input_ids"]
+            tokenized_input = tokenizer(
+                r.inputs, truncation=True, max_length=r.truncate
+            )["input_ids"]
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
@@ -208,7 +210,7 @@ class FlashCausalLM(Model):
             raise NotImplementedError("FlashCausalLM does not support quantization")
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         self.model = (
             model_cls.from_pretrained(
@@ -45,18 +45,19 @@ class FlashNeoXSharded(FlashNeoX):
             raise NotImplementedError("FlashNeoX does not support quantization")
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
 
         with init_empty_weights():
-            model = FlashGPTNeoXForCausalLM(config)
+            model = FlashGPTNeoXForCausalLM(config, self.process_group)
 
         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
@@ -147,32 +148,3 @@ class FlashNeoXSharded(FlashNeoX):
                     module._parameters[param_name] = tensor
                 else:
                     module._buffers[param_name] = tensor
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        max_s: int,
-        past_key_values: Optional = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.model.gpt_neox.tp_embeddings:
-            logits, present = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlens=cu_seqlens,
-                max_s=max_s,
-                past_key_values=past_key_values,
-            )
-
-            # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
-
-            return world_logits, present
-        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
-        else:
-            return super(FlashNeoXSharded, self).forward(
-                input_ids, position_ids, cu_seqlens, max_s, past_key_values
-            )
@@ -33,7 +33,7 @@ class FlashSantacoder(FlashCausalLM):
             raise NotImplementedError("FlashSantacoder does not support quantization")
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
         config = AutoConfig.from_pretrained(
@@ -56,6 +56,8 @@ class FlashSantacoder(FlashCausalLM):
         self.load_weights(
             model,
             filenames,
+            device,
+            dtype,
         )
         self.model = model.eval().to(device).to(dtype)
 
@@ -68,10 +70,14 @@ class FlashSantacoder(FlashCausalLM):
     def load_weights(
         model: FlashSantacoderForCausalLM,
         filenames: List[Path],
+        device: torch.device,
+        dtype: torch.dtype,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
+                value = value.to(device).to(dtype)
+
                 layer_name = ".".join(key.split(".")[:4])
 
                 # Fused qkv
@@ -141,6 +147,8 @@ class FlashSantacoder(FlashCausalLM):
                 else:
                     module._buffers[param_name] = value
 
+                del value
+
         torch.cuda.empty_cache()
         model.post_load_weights()
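Note: `load_weights` now receives the target device and dtype, so each tensor is cast as it is read from the checkpoint and the temporary is freed immediately, instead of materializing the whole checkpoint on CPU first. A condensed sketch of the pattern (the real parameter-copy step is elided; names are ours):

    import torch

    def load_weights(module_params, filenames, device, dtype):
        for filename in filenames:
            state_dict = torch.load(filename, map_location="cpu")
            for key, value in state_dict.items():
                value = value.to(device).to(dtype)  # cast early
                module_params[key] = value          # stand-in for the real copy
                del value                           # drop the local reference eagerly
        torch.cuda.empty_cache()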
@@ -96,7 +96,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         input_lengths = []
 
         # Parse batch
-        max_sequence_length = 0
+        max_truncation = 0
         padding_right_offset = 0
         for r in pb.requests:
             # Add escape_custom_split_sequence to the CausalLMBatch logic
@@ -107,7 +107,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
-            max_sequence_length = max(max_sequence_length, r.input_length)
+            max_truncation = max(max_truncation, r.truncate)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -118,14 +118,20 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             return_tensors="pt",
             padding=True,
             return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
         ).to(device)
 
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
-            (pb.size, max_sequence_length + padding_right_offset)
+            (pb.size, max_input_length + padding_right_offset)
         )
         # Copy tokenizer attention_mask into fully allocated attention_mask
-        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
 
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
@@ -143,7 +149,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
         )
 
@@ -188,7 +194,7 @@ class GalacticaSharded(Galactica):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
 
         config = AutoConfig.from_pretrained(
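Note: the attention mask is allocated once at its final width, the input length plus room for `max_new_tokens`, and the tokenizer's mask is copied into the left slice, so decode steps can grow into the preallocated right padding without reallocating. Runnable sketch with toy sizes:

    import torch

    batch, max_input_length, padding_right_offset = 2, 5, 3
    tokenizer_mask = torch.ones(batch, max_input_length, dtype=torch.long)

    # Allocate the maximum attention_mask up front...
    attention_mask = tokenizer_mask.new_zeros(
        (batch, max_input_length + padding_right_offset)
    )
    # ...and copy the tokenizer mask into its left slice
    attention_mask[:, :max_input_length] = tokenizer_mask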
@@ -44,7 +44,7 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.pad_token = tokenizer.eos_token
@@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.add_special_tokens(
             {
@@ -73,6 +73,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_lengths = []
 
         # Parse batch
+        max_truncation = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -84,6 +85,7 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            max_truncation = max(max_truncation, r.truncate)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -94,6 +96,8 @@ class Seq2SeqLMBatch(Batch):
             return_tensors="pt",
             padding=True,
             return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
         ).to(device)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
@@ -44,7 +44,7 @@ class T5Sharded(Seq2SeqLM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
         config = AutoConfig.from_pretrained(