Merge remote-tracking branch 'upstream/main' into fix_rocm_fa

This commit is contained in:
Mohit Sharma 2024-09-27 10:34:16 +00:00
commit 47c81d2924
45 changed files with 2627 additions and 2032 deletions

View File

@ -58,3 +58,6 @@ jobs:
- name: Run Rust tests
run: |
cargo test
- name: Run Rust tests with google feature
run: |
cargo test --features google

3
.gitignore vendored
View File

@ -3,9 +3,8 @@ target
router/tokenizer.json
*__pycache__*
backends/v2/src/client/pb
backends/v3/src/client/pb
backends/client/src/v2/pb
backends/client/src/v3/pb
# ROCm auto-generated files
*.hip

243
Cargo.lock generated
View File

@ -109,9 +109,9 @@ dependencies = [
[[package]]
name = "anyhow"
version = "1.0.88"
version = "1.0.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356"
checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6"
[[package]]
name = "arbitrary"
@ -177,9 +177,9 @@ dependencies = [
[[package]]
name = "async-trait"
version = "0.1.82"
version = "0.1.83"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1"
checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
dependencies = [
"proc-macro2",
"quote",
@ -269,9 +269,9 @@ dependencies = [
[[package]]
name = "aws-lc-sys"
version = "0.21.1"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234314bd569802ec87011d653d6815c6d7b9ffb969e9fee5b8b20ef860e8dce9"
checksum = "b3ddc4a5b231dd6958b140ff3151b6412b3f4321fab354f399eec8f14b06df62"
dependencies = [
"bindgen",
"cc",
@ -309,19 +309,19 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper 0.1.2",
"tokio",
"tower",
"tower 0.4.13",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum"
version = "0.7.5"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
checksum = "8f43644eed690f5374f1af436ecd6aea01cd201f6fbdf0178adaf6907afb2cec"
dependencies = [
"async-trait",
"axum-core 0.4.3",
"axum-core 0.4.4",
"bytes",
"futures-util",
"http 1.1.0",
@ -342,7 +342,7 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper 1.0.1",
"tokio",
"tower",
"tower 0.5.1",
"tower-layer",
"tower-service",
"tracing",
@ -367,9 +367,9 @@ dependencies = [
[[package]]
name = "axum-core"
version = "0.4.3"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3"
checksum = "5e6b8ba012a258d63c9adfa28b9ddcf66149da6f986c5b5452e629d5ee64bf00"
dependencies = [
"async-trait",
"bytes",
@ -380,7 +380,7 @@ dependencies = [
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper 0.1.2",
"sync_wrapper 1.0.1",
"tower-layer",
"tower-service",
"tracing",
@ -392,13 +392,13 @@ version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08"
dependencies = [
"axum 0.7.5",
"axum 0.7.6",
"futures-core",
"futures-util",
"http 1.1.0",
"opentelemetry 0.21.0",
"pin-project-lite",
"tower",
"tower 0.4.13",
"tracing",
"tracing-opentelemetry 0.22.0",
"tracing-opentelemetry-instrumentation-sdk",
@ -546,9 +546,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
[[package]]
name = "bytes"
version = "1.7.1"
version = "1.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50"
checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3"
[[package]]
name = "camino"
@ -605,9 +605,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.1.18"
version = "1.1.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476"
checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0"
dependencies = [
"jobserver",
"libc",
@ -675,9 +675,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.5.17"
version = "4.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac"
checksum = "b0956a43b323ac1afaffc053ed5c4b7c1f1800bacd1683c353aabbb752515dd3"
dependencies = [
"clap_builder",
"clap_derive",
@ -685,9 +685,9 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.5.17"
version = "4.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73"
checksum = "4d72166dd41634086d5803a47eb71ae740e61d84709c36f3c34110173db3961b"
dependencies = [
"anstream",
"anstyle",
@ -697,9 +697,9 @@ dependencies = [
[[package]]
name = "clap_derive"
version = "4.5.13"
version = "4.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0"
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
dependencies = [
"heck 0.5.0",
"proc-macro2",
@ -1732,9 +1732,9 @@ dependencies = [
[[package]]
name = "hyper-util"
version = "0.1.8"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba"
checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b"
dependencies = [
"bytes",
"futures-channel",
@ -1745,7 +1745,6 @@ dependencies = [
"pin-project-lite",
"socket2",
"tokio",
"tower",
"tower-service",
"tracing",
]
@ -1985,7 +1984,7 @@ dependencies = [
"anyhow",
"base64 0.21.7",
"bytecount",
"clap 4.5.17",
"clap 4.5.18",
"fancy-regex",
"fraction",
"getrandom",
@ -2025,9 +2024,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]]
name = "libc"
version = "0.2.158"
version = "0.2.159"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5"
[[package]]
name = "libfuzzer-sys"
@ -2240,9 +2239,9 @@ dependencies = [
[[package]]
name = "minijinja"
version = "2.2.0"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d7d3e3a3eece1fa4618237ad41e1de855ced47eab705cec1c9a920e1d1c5aad"
checksum = "1028b628753a7e1a88fc59c9ba4b02ecc3bc0bd3c7af23df667bc28df9b3310e"
dependencies = [
"serde",
"serde_json",
@ -2250,9 +2249,9 @@ dependencies = [
[[package]]
name = "minijinja-contrib"
version = "2.2.0"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "744a2b84dbd22398e347594ed2aef9d3f1b948934e3e6e94ef69ecd39d597f4b"
checksum = "39ffd46ee854be23604a20efd6c9655374fefbe4d44b949dc0f907305d92873a"
dependencies = [
"minijinja",
"serde",
@ -2600,9 +2599,9 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.20.0"
version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ea5043e58958ee56f3e15a90aee535795cd7dfd319846288d93c5b57d85cbe"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "onig"
@ -2955,9 +2954,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "pkg-config"
version = "0.3.30"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec"
checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2"
[[package]]
name = "plotters"
@ -3002,9 +3001,9 @@ dependencies = [
[[package]]
name = "portable-atomic"
version = "1.7.0"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce"
[[package]]
name = "powerfmt"
@ -3161,9 +3160,9 @@ dependencies = [
[[package]]
name = "pyo3"
version = "0.22.2"
version = "0.22.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433"
checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225"
dependencies = [
"cfg-if",
"indoc",
@ -3179,9 +3178,9 @@ dependencies = [
[[package]]
name = "pyo3-build-config"
version = "0.22.2"
version = "0.22.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8"
checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3"
dependencies = [
"once_cell",
"target-lexicon",
@ -3189,9 +3188,9 @@ dependencies = [
[[package]]
name = "pyo3-ffi"
version = "0.22.2"
version = "0.22.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6"
checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c"
dependencies = [
"libc",
"pyo3-build-config",
@ -3199,9 +3198,9 @@ dependencies = [
[[package]]
name = "pyo3-macros"
version = "0.22.2"
version = "0.22.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206"
checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
@ -3211,9 +3210,9 @@ dependencies = [
[[package]]
name = "pyo3-macros-backend"
version = "0.22.2"
version = "0.22.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372"
checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1"
dependencies = [
"heck 0.5.0",
"proc-macro2",
@ -3403,9 +3402,9 @@ dependencies = [
[[package]]
name = "redox_syscall"
version = "0.5.4"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0884ad60e090bf1345b93da0a5de8923c93884cd03f40dfcfddd3b4bee661853"
checksum = "62871f2d65009c0256aed1b9cfeeb8ac272833c404e13d53d400cd0dad7a2ac0"
dependencies = [
"bitflags 2.6.0",
]
@ -3770,9 +3769,9 @@ dependencies = [
[[package]]
name = "security-framework-sys"
version = "2.11.1"
version = "2.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf"
checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6"
dependencies = [
"core-foundation-sys",
"libc",
@ -4175,11 +4174,11 @@ dependencies = [
[[package]]
name = "text-generation-backends-trtllm"
version = "2.2.1-dev0"
version = "2.3.1-dev0"
dependencies = [
"async-stream",
"async-trait",
"clap 4.5.17",
"clap 4.5.18",
"cmake",
"cxx",
"cxx-build",
@ -4198,11 +4197,10 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
version = "2.2.1-dev0"
version = "2.3.1-dev0"
dependencies = [
"average",
"clap 4.5.17",
"crossterm",
"clap 4.5.18",
"float-ord",
"hf-hub",
"ratatui",
@ -4219,7 +4217,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "2.2.1-dev0"
version = "2.3.1-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
@ -4231,15 +4229,15 @@ dependencies = [
"tokio",
"tonic 0.10.2",
"tonic-build",
"tower",
"tower 0.4.13",
"tracing",
]
[[package]]
name = "text-generation-launcher"
version = "2.2.1-dev0"
version = "2.3.1-dev0"
dependencies = [
"clap 4.5.17",
"clap 4.5.18",
"ctrlc",
"float_eq",
"hf-hub",
@ -4256,14 +4254,14 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "2.2.1-dev0"
version = "2.3.1-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.5",
"axum 0.7.6",
"axum-tracing-opentelemetry",
"base64 0.22.1",
"clap 4.5.17",
"clap 4.5.18",
"csv",
"futures",
"futures-util",
@ -4304,15 +4302,64 @@ dependencies = [
]
[[package]]
name = "text-generation-router-v3"
version = "2.2.1-dev0"
name = "text-generation-router-v2"
version = "2.3.1-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.5",
"axum 0.7.6",
"axum-tracing-opentelemetry",
"base64 0.22.1",
"clap 4.5.17",
"clap 4.5.18",
"futures",
"futures-util",
"grpc-metadata",
"hf-hub",
"image",
"init-tracing-opentelemetry",
"jsonschema",
"metrics",
"metrics-exporter-prometheus",
"minijinja",
"minijinja-contrib",
"nohash-hasher",
"once_cell",
"opentelemetry 0.20.0",
"opentelemetry-otlp",
"prost 0.12.6",
"prost-build",
"rand",
"regex",
"reqwest",
"serde",
"serde_json",
"slotmap",
"text-generation-router",
"thiserror",
"tokenizers 0.20.0",
"tokio",
"tokio-stream",
"tonic 0.10.2",
"tonic-build",
"tower 0.4.13",
"tower-http",
"tracing",
"tracing-opentelemetry 0.21.0",
"tracing-subscriber",
"utoipa",
"utoipa-swagger-ui",
]
[[package]]
name = "text-generation-router-v3"
version = "2.3.1-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.6",
"axum-tracing-opentelemetry",
"base64 0.22.1",
"clap 4.5.18",
"criterion",
"futures",
"futures-util",
@ -4345,7 +4392,7 @@ dependencies = [
"tokio-stream",
"tonic 0.10.2",
"tonic-build",
"tower",
"tower 0.4.13",
"tower-http",
"tracing",
"tracing-opentelemetry 0.21.0",
@ -4365,18 +4412,18 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.63"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.63"
version = "1.0.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
dependencies = [
"proc-macro2",
"quote",
@ -4647,9 +4694,9 @@ dependencies = [
[[package]]
name = "toml_edit"
version = "0.22.20"
version = "0.22.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d"
checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5"
dependencies = [
"indexmap 2.5.0",
"serde",
@ -4680,7 +4727,7 @@ dependencies = [
"prost 0.11.9",
"tokio",
"tokio-stream",
"tower",
"tower 0.4.13",
"tower-layer",
"tower-service",
"tracing",
@ -4707,7 +4754,7 @@ dependencies = [
"prost 0.12.6",
"tokio",
"tokio-stream",
"tower",
"tower 0.4.13",
"tower-layer",
"tower-service",
"tracing",
@ -4746,6 +4793,22 @@ dependencies = [
"tracing",
]
[[package]]
name = "tower"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f"
dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
"sync_wrapper 0.1.2",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-http"
version = "0.5.2"
@ -4959,9 +5022,9 @@ checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
[[package]]
name = "unicode-normalization"
version = "0.1.23"
version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5"
checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956"
dependencies = [
"tinyvec",
]
@ -4994,9 +5057,9 @@ dependencies = [
[[package]]
name = "unicode-width"
version = "0.1.13"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d"
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
[[package]]
name = "unicode_categories"
@ -5096,7 +5159,7 @@ version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da"
dependencies = [
"axum 0.7.5",
"axum 0.7.6",
"mime_guess",
"regex",
"rust-embed",
@ -5313,9 +5376,9 @@ dependencies = [
[[package]]
name = "webpki-roots"
version = "0.26.5"
version = "0.26.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bd24728e5af82c6c4ec1b66ac4844bdf8156257fccda846ec58b42cd0cdbe6a"
checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958"
dependencies = [
"rustls-pki-types",
]
@ -5604,9 +5667,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.6.18"
version = "0.6.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f"
checksum = "c52ac009d615e79296318c1bcce2d422aaca15ad08515e344feeda07df67a587"
dependencies = [
"memchr",
]

View File

@ -1,26 +1,26 @@
[workspace]
members = [
"benchmark",
"backends/v2",
"backends/v3",
"backends/grpc-metadata",
"backends/trtllm",
"backends/client",
"launcher",
"router"
]
default-members = [
"benchmark",
"backends/v2",
"backends/v3",
"backends/grpc-metadata",
# "backends/trtllm",
"backends/client",
"launcher",
"router"
]
resolver = "2"
[workspace.package]
version = "2.2.1-dev0"
version = "2.3.1-dev0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -40,7 +40,6 @@ COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt
RUN cargo build --profile release-opt
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile

View File

@ -83,7 +83,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0 --model-id $model
ghcr.io/huggingface/text-generation-inference:2.3.0 --model-id $model
```
And then you can make requests like

View File

@ -1,17 +0,0 @@
{
mkPoetryApplication,
pkg-config,
protobuf,
openssl,
}:
mkPoetryApplication {
# name = "text-generation-server";
projectDir = ./server;
# nativeBuildInputs = [ pkg-config ];
# buildInputs = [ openssl.dev protobuf ];
}

75
backends/v2/Cargo.toml Normal file
View File

@ -0,0 +1,75 @@
[package]
name = "text-generation-router-v2"
description = "Text Generation Webserver"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-router-v2"
path = "src/main.rs"
[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
text-generation-router = { path = "../../router" }
clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
slotmap = "1.0.7"
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = [
"rt",
"rt-multi-thread",
"parking_lot",
"signal",
"sync",
] }
tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = [
"opentelemetry-otlp",
] }
minijinja = { workspace = true }
minijinja-contrib = { workspace = true }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = { workspace = true }
prost = "^0.12"
tonic = "^0.10"
tower = "^0.4"
[build-dependencies]
tonic-build = "0.10.1"
prost-build = "0.12.1"
[features]
default = ["ngrok"]
ngrok = ["text-generation-router/ngrok"]
google = ["text-generation-router/google"]
kserve = ["text-generation-router/kserve"]

19
backends/v2/build.rs Normal file
View File

@ -0,0 +1,19 @@
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/");
fs::create_dir_all("src/client/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/client/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
}

506
backends/v2/src/backend.rs Normal file
View File

@ -0,0 +1,506 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
pub struct BackendV2 {
/// Request queue
queue: Queue,
/// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>,
/// Client clone, used for health checks to skip the queue
client: ShardedClient,
}
impl BackendV2 {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
) -> Self {
// Infer shared state
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else {
Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
};
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client.clone(),
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
queue.clone(),
batching_task_notifier.clone(),
));
Self {
queue,
batching_task_notifier,
client,
}
}
}
#[async_trait]
impl Backend for BackendV2 {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();
// Append the request to the queue
self.queue.append(Entry {
request,
response_tx,
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
});
// Notify the background task that we have a new entry in the queue that needs
// to be batched
self.batching_task_notifier.notify_one();
// Return stream
Ok(UnboundedReceiverStream::new(response_rx))
}
async fn health(&self, current_health: bool) -> bool {
if current_health {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
self.client.model_health().await
}
.is_ok()
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
queue: Queue,
notifier: Arc<Notify>,
) {
// Infinite loop
loop {
// Wait for a notification from the Infer struct
notifier.notified().await;
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) = queue
.next_batch(
None,
max_batch_size,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
} else {
// Minimum batch size
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size =
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
}
// Create span for this batch to add context to inference calls
let next_batch_size = entries.len();
let next_batch_span =
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
#[instrument(skip_all)]
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
}
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
}
/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let mut batch = next_batch?;
// No need to filter
if batch.size as usize == entries.len() {
return Some(batch);
}
let id = batch.id;
// Retain only requests that are still in entries
batch.request_ids.retain(|id| entries.contains_key(id));
if batch.request_ids.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.clear_cache(Some(id)).await.unwrap();
None
} else {
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, batch.request_ids).await.unwrap()
}
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation responses back to the infer task
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).inspect_err(|_err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
});
}
/// Send responses through the `entry` response channel
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let prefill_tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs)
.zip(prefill_tokens.texts)
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
}
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
let token = Token {
id,
text,
logprob,
special,
};
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids
.iter()
.zip(top_tokens_.logprobs.iter())
.zip(top_tokens_.texts.iter())
.zip(top_tokens_.is_special.iter())
.map(|(((&id, &logprob), text), &special)| Token {
id,
text: text.to_string(),
logprob,
special,
})
.collect()
} else {
vec![]
};
match (&generation.generated_text, iterator.peek()) {
(Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: GeneratedText::from(generated_text.clone()),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
}
}
Ok(stopped)
}
/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
});
}
impl From<crate::client::GeneratedText> for GeneratedText {
fn from(value: crate::client::GeneratedText) -> Self {
let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v2_finish_reason {
crate::client::FinishReason::Length => FinishReason::Length,
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
};
Self {
text: value.text,
generated_tokens: value.generated_tokens,
finish_reason,
seed: value.seed,
}
}
}

View File

@ -0,0 +1,257 @@
/// Single shard Client
use crate::client::pb;
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v2::*;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
stub: TextGenerationServiceClient<Channel>,
}
impl Client {
/// Returns a client connected to the given url
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
.await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await.map_err(|_| {
ClientError::Connection("Server does not support v2 interface".to_string())
})?;
let urls = response
.into_inner()
.urls
.into_iter()
// Remove unix socket prefix
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
/// Get model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<InfoResponse> {
let request = tonic::Request::new(InfoRequest {}).inject_context();
let response = self.stub.info(request).await?.into_inner();
Ok(response)
}
/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let request = tonic::Request::new(HealthRequest {}).inject_context();
let response = self.stub.health(request).await?.into_inner();
Ok(response)
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
self.stub.clear_cache(request).await?;
Ok(())
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str(&format!(
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
}
requests.push(Request {
id: 0,
inputs,
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
});
n_tokens += max_input_length;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
break;
}
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
response.batch,
DecodeTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
}
pub struct PrefillTimings {
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}
pub struct DecodeTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl DecodeTimings {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}

View File

@ -0,0 +1,68 @@
//! Text Generation gRPC client library
use async_trait::async_trait;
use thiserror::Error;
use tonic::transport;
use tonic::Status;
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod grpc_client;
mod sharded_client;
pub use grpc_client::Client;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse,
InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
#[async_trait]
pub trait Health {
/// Check if a generate server is healthy by asking it to allocate a tensor on device
async fn device_health(&self) -> Result<()>;
/// Check if a generate server is healthy by doing a forward pass.
/// EXPENSIVE
async fn model_health(&self) -> Result<()>;
}
#[derive(Debug)]
pub struct ShardInfo {
pub requires_padding: bool,
pub dtype: String,
pub device_type: String,
pub window_size: Option<u32>,
pub speculate: u32,
}
#[derive(Error, Debug, Clone)]
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")]
Connection(String),
#[error("Server error: {0}")]
Generation(String),
#[error("Sharded results are empty")]
EmptyResults,
}
impl From<Status> for ClientError {
fn from(err: Status) -> Self {
let err = Self::Generation(err.message().to_string());
tracing::error!("{err}");
err
}
}
impl From<transport::Error> for ClientError {
fn from(err: transport::Error) -> Self {
let err = Self::Connection(err.to_string());
tracing::error!("{err}");
err
}
}
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
pub type Result<T> = std::result::Result<T, ClientError>;

View File

@ -0,0 +1,252 @@
/// Multi shard Client
use crate::client::{ClientError, Result};
use crate::client::{Health, ShardInfo};
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::InfoResponse;
use crate::client::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
clients: Vec<Client>,
}
impl ShardedClient {
fn new(clients: Vec<Client>) -> Self {
Self { clients }
}
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await?;
let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given uri
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
}
/// GRPC health check
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.health())
.collect();
join_all(futures).await.pop().unwrap()
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.clear_cache(batch_id))
.collect();
join_all(futures).await.into_iter().collect()
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_size,
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<Option<u32>>>>()?;
Ok(results.into_iter().flatten().min())
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
impl From<InfoResponse> for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> {
self.clone().health().await?;
Ok(())
}
async fn model_health(&self) -> Result<()> {
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: u64::MAX,
inputs: "liveness".to_string(),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
};
self.clone().prefill(batch).await?;
Ok(())
}
}

141
backends/v2/src/lib.rs Normal file
View File

@ -0,0 +1,141 @@
mod backend;
mod client;
mod queue;
use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV2;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
/// Mandatory
#[schema(example = "cuda")]
pub model_device_type: String,
#[schema(example = "torch.float16")]
pub model_dtype: String,
/// Backend parameters
#[schema(example = "1")]
pub speculate: usize,
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
#[schema(example = "20")]
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
}
#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
max_input_tokens: usize,
max_total_tokens: usize,
master_shard_uds_path: String,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
) -> Result<(BackendV2, BackendInfo), V2Error> {
// Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(V2Error::NotEnoughMemory(max_total_tokens));
}
Ok(max_supported_batch_total_tokens)
}
}
};
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(V2Error::Connection)?;
// server is running on v2
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(V2Error::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(V2Error::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(V2Error::Warmup)?,
)?;
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
let backend_info = BackendInfo {
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
model_device_type: shard_info.device_type.clone(),
model_dtype: shard_info.dtype.clone(),
speculate: shard_info.speculate as usize,
};
let backend = BackendV2::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
);
tracing::info!("Using backend V3");
Ok((backend, backend_info))
}
#[derive(Debug, Error)]
pub enum V2Error {
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
}

212
backends/v2/src/main.rs Normal file
View File

@ -0,0 +1,212 @@
use clap::{Parser, Subcommand};
use text_generation_router::{server, usage_stats};
use text_generation_router_v2::{connect_backend, V2Error};
use thiserror::Error;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
command,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
hostname,
port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
revision,
validation_workers,
api_key,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
} = args;
if let Some(Commands::PrintSchema) = command {
use utoipa::OpenApi;
let api_doc = text_generation_router::server::ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
};
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
if let Some(max_batch_size) = max_batch_size {
if max_batch_size == 0 {
return Err(RouterError::ArgumentValidation(
"`max_batch_size` must be > 0".to_string(),
));
}
}
let (backend, _backend_info) = connect_backend(
max_input_tokens,
max_total_tokens,
master_shard_uds_path,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
)
.await?;
// Run server
server::run(
backend,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
validation_workers,
api_key,
tokenizer_name,
tokenizer_config_path,
revision,
hostname,
port,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
)
.await?;
Ok(())
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("Backend failed: {0}")]
Backend(#[from] V2Error),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}

View File

@ -1,14 +1,14 @@
use crate::infer::{InferError, InferStreamResponse};
use crate::validation::{
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
use crate::client::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::min;
use std::collections::VecDeque;
use text_generation_client::v2::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use text_generation_client::ChunksToString;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
@ -218,7 +218,7 @@ impl State {
// Create span for this batch to add context to inference calls
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
next_batch_span.follows_from(&Span::current());
next_batch_span.follows_from(Span::current());
let mut batch_requests = Vec::with_capacity(self.entries.len());
let mut batch_entries =
@ -404,6 +404,7 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use tracing::info_span;
fn default_entry() -> (
@ -415,7 +416,9 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: vec![],
input_ids: Some(Arc::new(vec![])),
input_length: 0,
add_special_tokens: true,
truncate: 0,
decoder_input_details: false,
parameters: ValidParameters {

View File

@ -16,7 +16,6 @@ path = "src/main.rs"
[dependencies]
average = "0.14"
clap = { version = "4.4.5", features = ["derive", "env"] }
crossterm = "0.28.1"
float-ord = "0.3.2"
serde = {version = "1.0.188", features = ["derive"]}
serde_json = "1.0"
@ -25,7 +24,7 @@ text-generation-client = { path = "../backends/client" }
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
ratatui = { version = "0.28.1", default-features = false, features = ["crossterm"] }
ratatui = "0.28.1"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
hf-hub = { workspace = true }

View File

@ -7,7 +7,7 @@
</div>
A lightweight benchmarking tool based inspired by [oha](https://github.com/hatoo/oha)
and powered by [tui](https://github.com/tui-rs-revival/ratatui).
and powered by [Ratatui](https://github.com/ratatui/ratatui).
## Install

View File

@ -1,6 +1,6 @@
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
use crate::generation::{Decode, Message, Prefill};
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use ratatui::layout::{Alignment, Constraint, Direction, Layout};
use ratatui::style::{Color, Modifier, Style};
use ratatui::text::{Line, Span};

View File

@ -1,5 +1,5 @@
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
use crossterm::event;
use ratatui::crossterm::event;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, mpsc};

View File

@ -6,8 +6,8 @@ mod utils;
use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use ratatui::backend::CrosstermBackend;
use ratatui::crossterm::ExecutableCommand;
use ratatui::Terminal;
use std::io;
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
@ -50,9 +50,9 @@ pub async fn run(
};
// Initialize terminal properties
crossterm::terminal::enable_raw_mode()?;
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
io::stdout().execute(crossterm::cursor::Hide)?;
ratatui::crossterm::terminal::enable_raw_mode()?;
io::stdout().execute(ratatui::crossterm::terminal::EnterAlternateScreen)?;
io::stdout().execute(ratatui::crossterm::cursor::Hide)?;
// Initialize terminal
let mut terminal = {
@ -128,9 +128,9 @@ pub async fn run(
let _ = shutdown_guard_receiver.recv().await;
// Revert terminal to original view
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
crossterm::terminal::disable_raw_mode()?;
io::stdout().execute(crossterm::cursor::Show)?;
io::stdout().execute(ratatui::crossterm::terminal::LeaveAlternateScreen)?;
ratatui::crossterm::terminal::disable_raw_mode()?;
io::stdout().execute(ratatui::crossterm::cursor::Show)?;
let parameters_table = table::parameters_table(
tokenizer_name,

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "2.2.1-dev0"
"version": "2.3.1-dev0"
},
"paths": {
"/": {

View File

@ -36,7 +36,7 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
```
additionally, you can specify the path to the LoRA models using the `LORA_ADAPTERS_PATH` environment variable. For example:
To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"`
```bash
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
@ -72,6 +72,22 @@ curl 127.0.0.1:3000/generate \
}'
```
If you are using a lora adapter stored locally that was set in the following manner: `LORA_ADAPTERS=myadapter=/some/path/to/adapter`, here is an example payload:
```json
curl 127.0.0.1:3000/generate \
-X POST \
-H 'Content-Type: application/json' \
-d '{
"inputs": "Hello who are you?",
"parameters": {
"max_new_tokens": 40,
"adapter_id": "myadapter"
}
}'
```
> **Note:** The Lora feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally documentation and an improved client library will be published soon.
An updated tutorial with detailed examples will be published soon. Stay tuned!

View File

@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0-rocm \
ghcr.io/huggingface/text-generation-inference:2.3.0-rocm \
--model-id $model
```

View File

@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0-intel-xpu \
ghcr.io/huggingface/text-generation-inference:2.3.0-intel-xpu \
--model-id $model --cuda-graphs 0
```
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0-intel-cpu \
ghcr.io/huggingface/text-generation-inference:2.3.0-intel-cpu \
--model-id $model --cuda-graphs 0
```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0 \
ghcr.io/huggingface/text-generation-inference:2.3.0 \
--model-id $model
```

View File

@ -11,10 +11,19 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.2.0 \
ghcr.io/huggingface/text-generation-inference:2.3.0 \
--model-id $model
```
<Tip>
If you want to serve gated or private models, which provide
controlled access to sensitive or proprietary content, refer to
[this guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/gated_model_access)
for detailed instructions.
</Tip>
### Supported hardware
TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.

View File

@ -24,13 +24,13 @@
"tokens": [
{
"id": 1736,
"logprob": -2.03125,
"logprob": -2.109375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.8671875,
"logprob": -1.90625,
"special": false,
"text": "\n\n"
},
@ -42,48 +42,48 @@
},
{
"id": 2121,
"logprob": -1.8125,
"logprob": -1.796875,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.24121094,
"logprob": -0.24511719,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.100097656,
"logprob": -0.09326172,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.9453125,
"logprob": -0.95703125,
"special": false,
"text": " is"
},
{
"id": 476,
"logprob": -1.703125,
"id": 1671,
"logprob": -1.5859375,
"special": false,
"text": " a"
"text": " used"
},
{
"id": 4551,
"logprob": -2.453125,
"id": 577,
"logprob": -0.39257812,
"special": false,
"text": " document"
"text": " to"
},
{
"id": 674,
"logprob": -0.796875,
"id": 3853,
"logprob": -1.25,
"special": false,
"text": " that"
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is a document that"
"generated_text": " form\n\nThe test request form is used to request"
}

View File

@ -11,12 +11,12 @@
},
{
"id": 2015,
"logprob": -9.640625,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.375,
"logprob": -10.3671875,
"text": " request"
}
],
@ -24,19 +24,19 @@
"tokens": [
{
"id": 604,
"logprob": -0.2824707,
"logprob": -0.28271484,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -0.19030762,
"logprob": -0.18493652,
"special": false,
"text": " the"
},
{
"id": 16819,
"logprob": -1.4892578,
"logprob": -1.4804688,
"special": false,
"text": " detection"
},
@ -46,44 +46,44 @@
"special": false,
"text": " of"
},
{
"id": 573,
"logprob": -2.0195312,
"special": false,
"text": " the"
},
{
"id": 8566,
"logprob": 0.0,
"special": false,
"text": " presence"
},
{
"id": 689,
"logprob": -0.16491699,
"special": false,
"text": " or"
},
{
"id": 14862,
"logprob": 0.0,
"special": false,
"text": " absence"
},
{
"id": 576,
"logprob": -0.9946289,
"special": false,
"text": " of"
},
{
"id": 671,
"logprob": -0.5263672,
"logprob": -2.1738281,
"special": false,
"text": " an"
},
{
"id": 24646,
"logprob": -3.0449219,
"special": false,
"text": " RNA"
},
{
"id": 12369,
"logprob": -0.19299316,
"special": false,
"text": " virus"
},
{
"id": 575,
"logprob": -0.10632324,
"special": false,
"text": " in"
},
{
"id": 6022,
"logprob": -0.98095703,
"special": false,
"text": " patients"
},
{
"id": 1064,
"logprob": -1.3095703,
"special": false,
"text": " who"
}
],
"top_tokens": null
},
"generated_text": "Test request for the detection of the presence or absence of an"
"generated_text": "Test request for the detection of an RNA virus in patients who"
}

View File

@ -1038,6 +1038,7 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
Ok(log) => log.trace(),
// For interactive debugging ?
Err(_) => {
if LevelFilter::current() >= tracing::Level::DEBUG {
stdout.write_all(line).unwrap();
if lines.peek().is_some() {
stdout.write_all(b"\n").unwrap();
@ -1049,6 +1050,7 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
}
}
}
}
}
fn find_num_shards(

View File

@ -8,9 +8,11 @@ use crate::{
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
Message, PrefillToken, Token,
};
use async_stream::stream;
use async_trait::async_trait;
use chat_template::ChatTemplate;
use futures::future::try_join_all;
use futures::Stream;
use minijinja::ErrorKind;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
@ -87,7 +89,14 @@ impl Infer {
pub(crate) async fn generate_stream<'a>(
&'a self,
request: GenerateRequest,
) -> Result<GenerateStreamResponse, InferError> {
) -> Result<
(
OwnedSemaphorePermit,
u32, // input_length
impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,
),
InferError,
> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
.clone()
@ -107,9 +116,18 @@ impl Infer {
})?;
let input_length = valid_request.input_length;
let generation_stream = self.backend.schedule(valid_request)?;
let mut generation_stream = self.backend.schedule(valid_request)?;
Ok((permit, input_length, generation_stream))
// Wrap generation stream to update the backend health if the stream contains an error
let final_stream = stream! {
while let Some(response) = generation_stream.next().await {
yield response.inspect_err(|_err| {
self.backend_health.store(false, Ordering::SeqCst);
})
}
};
Ok((permit, input_length, final_stream))
}
/// Tokenizer the input
@ -278,13 +296,6 @@ impl Infer {
}
}
/// Type alias for generation responses
pub(crate) type GenerateStreamResponse = (
OwnedSemaphorePermit,
u32, // input_length
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
);
#[derive(Debug)]
pub struct GeneratedText {
pub text: String,

View File

@ -1,4 +0,0 @@
mod queue;
mod scheduler;
pub(crate) use scheduler::BackendV2;

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,10 @@ mod kserve;
pub mod logging;
pub mod usage_stats;
mod vertex;
use crate::infer::{Infer, InferError};
use crate::server::prepare_chat_input;
use serde::{Deserialize, Serialize};
use tracing::warn;
use utoipa::ToSchema;
@ -54,32 +57,6 @@ impl std::str::FromStr for Attention {
}
}
#[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct GenerateVertexInstance {
#[schema(example = "What is Deep Learning?")]
pub inputs: String,
#[schema(nullable = true, default = "null", example = "null")]
pub parameters: Option<GenerateParameters>,
}
#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
enum VertexInstance {
Generate(GenerateVertexInstance),
Chat(ChatRequest),
}
#[derive(Deserialize, ToSchema)]
pub(crate) struct VertexRequest {
#[serde(rename = "instances")]
pub instances: Vec<VertexInstance>,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
pub predictions: Vec<String>,
}
/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {
@ -174,6 +151,7 @@ impl HubProcessorConfig {
}
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(PartialEq))]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
/// A string that represents a [JSON Schema](https://json-schema.org/).
@ -230,6 +208,7 @@ pub struct Info {
}
#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) struct GenerateParameters {
/// Generate best_of sequences and return the one if the highest token logprobs.
#[serde(default)]
@ -774,6 +753,7 @@ impl ChatCompletionChunk {
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(Debug, PartialEq, Default))]
pub(crate) struct ChatRequest {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
@ -890,7 +870,82 @@ pub(crate) struct ChatRequest {
pub stream_options: Option<StreamOptions>,
}
impl ChatRequest {
fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {
let ChatRequest {
model,
max_tokens,
messages,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
guideline,
presence_penalty,
frequency_penalty,
top_p,
top_logprobs,
..
} = self;
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
let (inputs, grammar, using_tools) = prepare_chat_input(
infer,
response_format,
tools,
tool_choice,
&tool_prompt,
guideline,
messages,
)?;
Ok((
GenerateRequest {
inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty,
frequency_penalty,
top_k: None,
top_p,
typical_p: None,
do_sample,
max_new_tokens,
return_full_text: None,
stop,
truncate: None,
watermark: false,
details: true,
decoder_input_details: !stream,
seed,
top_n_tokens: top_logprobs,
grammar,
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
},
using_tools,
))
}
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
#[cfg_attr(test, derive(Debug, PartialEq))]
struct StreamOptions {
/// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
#[schema(example = "true")]
@ -984,6 +1039,7 @@ pub(crate) struct FunctionDefinition {
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
#[cfg_attr(test, derive(PartialEq))]
pub(crate) struct Tool {
// The type of the tool. Currently, only 'function' is supported.
#[schema(example = "function")]

View File

@ -8,7 +8,8 @@ use crate::kserve::{
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::validation::ValidationError;
use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse;
use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@ -20,8 +21,7 @@ use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use crate::{ModelInfo, ModelsInfo};
@ -149,63 +149,11 @@ async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
)]
async fn get_chat_tokenize(
Extension(infer): Extension<Infer>,
Json(req): Json<ChatRequest>,
Json(chat): Json<ChatRequest>,
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
metrics::counter!("tgi_request_count").increment(1);
let ChatRequest {
model,
max_tokens,
messages,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
guideline,
..
} = req;
let tool_prompt = tool_prompt.unwrap_or_default();
let (inputs, _grammar, _using_tools) = prepare_chat_input(
&infer,
response_format,
tools,
tool_choice,
&tool_prompt,
guideline,
messages,
)?;
let generate_request = GenerateRequest {
inputs,
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty: None,
frequency_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
do_sample: true,
max_new_tokens: max_tokens,
return_full_text: None,
stop: stop.unwrap_or_default(),
truncate: None,
watermark: false,
details: false,
decoder_input_details: !stream,
seed,
top_n_tokens: None,
grammar: _grammar,
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
},
};
let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
let input = generate_request.inputs.clone();
let encoding = infer.tokenize(generate_request).await?;
if let Some(encoding) = encoding {
@ -1162,77 +1110,20 @@ async fn chat_completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
Json(req): Json<ChatRequest>,
Json(chat): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
let ChatRequest {
model,
logprobs,
max_tokens,
messages,
presence_penalty,
seed,
stop,
stream,
stream_options,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
guideline,
logprobs,
..
} = req;
} = chat.clone();
let (generate_request, using_tools): (GenerateRequest, bool) =
chat.try_into_generate(&infer)?;
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let logprobs = logprobs.unwrap_or(false);
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
let (inputs, grammar, using_tools) = prepare_chat_input(
&infer,
response_format,
tools,
tool_choice,
&tool_prompt,
guideline,
messages,
)?;
// build the request passing some parameters
let generate_request = GenerateRequest {
inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty,
frequency_penalty: req.frequency_penalty,
top_k: None,
top_p: req.top_p,
typical_p: None,
do_sample,
max_new_tokens,
return_full_text: None,
stop,
truncate: None,
watermark: false,
details: true,
decoder_input_details: !stream,
seed,
top_n_tokens: req.top_logprobs,
grammar,
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
};
let logprobs = logprobs.unwrap_or_default();
// static values that will be returned in all cases
let model_id = info.model_id.clone();
@ -1385,186 +1276,6 @@ async fn chat_completions(
}
}
/// Generate tokens from Vertex request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/vertex",
request_body = VertexRequest,
responses(
(status = 200, description = "Generated Text", body = VertexResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(
skip_all,
fields(
total_time,
validation_time,
queue_time,
inference_time,
time_per_token,
seed,
)
)]
async fn vertex_compatibility(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
// check that theres at least one instance
if req.instances.is_empty() {
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Input validation error".to_string(),
error_type: "Input validation error".to_string(),
}),
));
}
// Prepare futures for all instances
let mut futures = Vec::with_capacity(req.instances.len());
for instance in req.instances.iter() {
let generate_request = match instance {
VertexInstance::Generate(instance) => GenerateRequest {
inputs: instance.inputs.clone(),
add_special_tokens: true,
parameters: GenerateParameters {
do_sample: true,
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
seed: instance.parameters.as_ref().and_then(|p| p.seed),
details: true,
decoder_input_details: true,
..Default::default()
},
},
VertexInstance::Chat(instance) => {
let ChatRequest {
model,
max_tokens,
messages,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
guideline,
presence_penalty,
frequency_penalty,
top_p,
top_logprobs,
..
} = instance.clone();
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
let (inputs, grammar, _using_tools) = match prepare_chat_input(
&infer,
response_format,
tools,
tool_choice,
&tool_prompt,
guideline,
messages,
) {
Ok(result) => result,
Err(e) => {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Failed to prepare chat input: {}", e),
error_type: "Input preparation error".to_string(),
}),
));
}
};
GenerateRequest {
inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty,
frequency_penalty,
top_k: None,
top_p,
typical_p: None,
do_sample,
max_new_tokens,
return_full_text: None,
stop,
truncate: None,
watermark: false,
details: true,
decoder_input_details: !stream,
seed,
top_n_tokens: top_logprobs,
grammar,
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
}
}
};
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
futures.push(async move {
generate_internal(
Extension(infer_clone),
compute_type_clone,
Json(generate_request),
span_clone,
)
.await
.map(|(_, Json(generation))| generation.generated_text)
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Incomplete generation".into(),
error_type: "Incomplete generation".into(),
}),
)
})
});
}
// execute all futures in parallel, collect results, returning early if any error occurs
let results = futures::future::join_all(futures).await;
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
let predictions = predictions?;
let response = VertexResponse { predictions };
Ok((HeaderMap::new(), Json(response)).into_response())
}
/// Tokenize inputs
#[utoipa::path(
post,
@ -2318,7 +2029,8 @@ async fn start(
#[cfg(feature = "google")]
{
use crate::VertexInstance;
use crate::vertex::__path_vertex_compatibility;
use crate::vertex::{VertexInstance, VertexRequest, VertexResponse};
#[derive(OpenApi)]
#[openapi(
@ -2637,7 +2349,7 @@ pub enum WebServerError {
type PreparedInput = (String, Option<GrammarType>, bool);
fn prepare_chat_input(
pub(crate) fn prepare_chat_input(
infer: &Infer,
response_format: Option<GrammarType>,
tools: Option<Vec<Tool>>,

360
router/src/vertex.rs Normal file
View File

@ -0,0 +1,360 @@
use crate::infer::Infer;
use crate::server::{generate_internal, ComputeType};
use crate::{
ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
StreamOptions, Tool, ToolChoice,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct GenerateVertexInstance {
#[schema(example = "What is Deep Learning?")]
pub inputs: String,
#[schema(nullable = true, default = "null", example = "null")]
pub parameters: Option<GenerateParameters>,
}
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexChat {
messages: Vec<Message>,
// Messages is ignored there.
#[serde(default)]
parameters: VertexParameters,
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexParameters {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: Option<String>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)]
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
/// UNUSED
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
/// result in a ban or exclusive selection of the relevant token.
#[serde(default)]
pub logit_bias: Option<Vec<f32>>,
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
/// output token returned in the content of message.
#[serde(default)]
#[schema(example = "false")]
pub logprobs: Option<bool>,
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
/// an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(default)]
#[schema(example = "5")]
pub top_logprobs: Option<u32>,
/// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)]
#[schema(example = "32")]
pub max_tokens: Option<u32>,
/// UNUSED
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
#[serde(default)]
#[schema(nullable = true, example = "2")]
pub n: Option<u32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics
#[serde(default)]
#[schema(nullable = true, example = 0.1)]
pub presence_penalty: Option<f32>,
/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stop: Option<Vec<String>>,
#[serde(default = "bool::default")]
pub stream: bool,
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
/// lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
#[serde(default)]
#[schema(nullable = true, example = 1.0)]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
#[serde(default)]
#[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>,
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
/// functions the model may generate JSON inputs for.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tools: Option<Vec<Tool>>,
/// A prompt to be appended before the tools
#[serde(default)]
#[schema(
nullable = true,
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
)]
pub tool_prompt: Option<String>,
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tool_choice: ToolChoice,
/// Response format constraints for the generation.
///
/// NOTE: A request can use `response_format` OR `tools` but not both.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub response_format: Option<GrammarType>,
/// A guideline to be used in the chat_template
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub guideline: Option<String>,
/// Options for streaming response. Only set this when you set stream: true.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stream_options: Option<StreamOptions>,
}
impl From<VertexChat> for ChatRequest {
fn from(val: VertexChat) -> Self {
Self {
messages: val.messages,
frequency_penalty: val.parameters.frequency_penalty,
guideline: val.parameters.guideline,
logit_bias: val.parameters.logit_bias,
logprobs: val.parameters.logprobs,
max_tokens: val.parameters.max_tokens,
model: val.parameters.model,
n: val.parameters.n,
presence_penalty: val.parameters.presence_penalty,
response_format: val.parameters.response_format,
seed: val.parameters.seed,
stop: val.parameters.stop,
stream_options: val.parameters.stream_options,
stream: val.parameters.stream,
temperature: val.parameters.temperature,
tool_choice: val.parameters.tool_choice,
tool_prompt: val.parameters.tool_prompt,
tools: val.parameters.tools,
top_logprobs: val.parameters.top_logprobs,
top_p: val.parameters.top_p,
}
}
}
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
#[serde(untagged)]
pub(crate) enum VertexInstance {
Generate(GenerateVertexInstance),
Chat(VertexChat),
}
#[derive(Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexRequest {
#[serde(rename = "instances")]
pub instances: Vec<VertexInstance>,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
pub(crate) struct VertexResponse {
pub predictions: Vec<String>,
}
/// Generate tokens from Vertex request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/vertex",
request_body = VertexRequest,
responses(
(status = 200, description = "Generated Text", body = VertexResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(
skip_all,
fields(
total_time,
validation_time,
queue_time,
inference_time,
time_per_token,
seed,
)
)]
pub(crate) async fn vertex_compatibility(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::counter!("tgi_request_count").increment(1);
// check that theres at least one instance
if req.instances.is_empty() {
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: "Input validation error".to_string(),
error_type: "Input validation error".to_string(),
}),
));
}
// Prepare futures for all instances
let mut futures = Vec::with_capacity(req.instances.len());
for instance in req.instances.into_iter() {
let generate_request = match instance {
VertexInstance::Generate(instance) => GenerateRequest {
inputs: instance.inputs.clone(),
add_special_tokens: true,
parameters: GenerateParameters {
do_sample: true,
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
seed: instance.parameters.as_ref().and_then(|p| p.seed),
details: true,
decoder_input_details: true,
..Default::default()
},
},
VertexInstance::Chat(instance) => {
let chat_request: ChatRequest = instance.into();
let (generate_request, _using_tools): (GenerateRequest, bool) =
chat_request.try_into_generate(&infer)?;
generate_request
}
};
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
futures.push(async move {
generate_internal(
Extension(infer_clone),
compute_type_clone,
Json(generate_request),
span_clone,
)
.await
.map(|(_, Json(generation))| generation.generated_text)
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Incomplete generation".into(),
error_type: "Incomplete generation".into(),
}),
)
})
});
}
// execute all futures in parallel, collect results, returning early if any error occurs
let results = futures::future::join_all(futures).await;
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
let predictions = predictions?;
let response = VertexResponse { predictions };
Ok((HeaderMap::new(), Json(response)).into_response())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Message, MessageContent};
#[test]
fn vertex_deserialization() {
let string = serde_json::json!({
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
"parameters": {
"max_tokens": 128,
"top_p": 0.95,
"temperature": 0.7
}
});
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
let string = serde_json::json!({
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
});
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
let string = serde_json::json!({
"instances": [
{
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
"parameters": {
"max_tokens": 128,
"top_p": 0.95,
"temperature": 0.7
}
}
]
});
let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
assert_eq!(
request,
VertexRequest {
instances: vec![VertexInstance::Chat(VertexChat {
messages: vec![Message {
role: "user".to_string(),
content: MessageContent::SingleText("What's Deep Learning?".to_string()),
name: None,
},],
parameters: VertexParameters {
max_tokens: Some(128),
top_p: Some(0.95),
temperature: Some(0.7),
..Default::default()
}
})]
}
);
}
}

View File

@ -152,11 +152,13 @@ def create_decode_state(
):
"""Create a decode state."""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=False,
use_tensor_cores=num_heads // num_kv_heads > 4,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)
@ -175,6 +177,7 @@ def create_decode_state_cuda_graphs(
therefore stored as part of the state.
"""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
@ -182,7 +185,8 @@ def create_decode_state_cuda_graphs(
paged_kv_indices_buffer=block_tables,
paged_kv_indptr_buffer=block_tables_ptr,
paged_kv_last_page_len_buffer=last_page_len,
use_tensor_cores=num_heads // num_kv_heads > 4,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)

View File

@ -87,9 +87,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(
f"{prefix}.weight_scale", to_dtype=False
).reshape(-1)
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -113,9 +115,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if scale.numel() > 1:
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
).reshape(-1)
f"{prefix}.weight_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
scale = scale.reshape(-1).expand(w.shape[0])
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -132,16 +141,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = [
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
]
shapes = [x.shape for x in w]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = [
weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
for p in prefixes
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
scale = torch.cat(scale, dim=0).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -157,9 +169,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = weights.get_tensor(
f"{prefix}.weight_scale", to_dtype=False
).reshape(-1)
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
)
return Fp8Weight(
weight=w,
weight_scale=scale,
@ -182,6 +196,9 @@ class Fp8Weight(Weight):
def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None:
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
# This is not checked by the fbgemm kernels, but they require contiguous
# memory. Can be non-contiguous when we e.g. expand from scalars.
self.weight_scale = self.weight_scale.contiguous()
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
)
@ -222,6 +239,9 @@ class Fp8Linear(torch.nn.Module):
@classmethod
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
if FBGEMM_DYN_AVAILABLE:
# fbgemm needs float32 scales.
scale = scale.float()
return cls(
qweight=weight,
scale=scale,
@ -256,3 +276,10 @@ class Fp8Linear(torch.nn.Module):
bias=self.bias,
)
return output
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
scale = weights.get_tensor(prefix, to_dtype=False)
if scale.numel() > 1:
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
return scale.reshape(-1).expand(shape[0])

View File

@ -1,15 +1,181 @@
from typing import Optional
from typing import Optional, Protocol, runtime_checkable
import torch
import torch.nn as nn
from loguru import logger
from transformers.activations import ACT2FN
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
)
if SYSTEM == "rocm":
from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import fused_topk, grouped_topk
# NOTE: we are using a protocol here, because multiple inherance is not nice.
# We need `Module`, and `Module` -> some abstract class -> some concrete
# class inheritance is whacky.
@runtime_checkable
class MoELayer(Protocol):
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
hidden_act: str = "silu",
): ...
def forward(
self, x: torch.Tensor, *, gating_output: torch.Tensor
) -> torch.Tensor: ...
class DenseMoELayer(nn.Module):
"""
Layer for MoE that applies *all* experts to each tokens and then weights
their outputs based on the calculated routing. This layer is much slower
than `SparseMoELayer` and should only be used when no fused kernels are
available (e.g. for unsupported quantizers).
"""
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
hidden_act: str = "silu",
):
super().__init__()
log_once(
logger.info,
"No fused layers are available for this model type, using (slower) dense MoE layer",
)
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.n_experts = n_experts
self.renormalize = renormalize
self.topk = topk
self.topk_group = topk_group
if "gelu" in hidden_act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh"
if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none"
),
)
elif "silu" in hidden_act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[hidden_act]
self.gate_proj = [
TensorParallelColumnLinear.load(
None,
prefix=f"{prefix}.{i}.{gate_proj_name}",
weights=weights,
bias=False,
)
for i in range(self.n_experts)
]
self.up_proj = [
TensorParallelColumnLinear.load(
None,
prefix=f"{prefix}.{i}.{up_proj_name}",
weights=weights,
bias=False,
)
for i in range(self.n_experts)
]
self.down_proj = [
TensorParallelRowLinear.load(
None,
prefix=f"{prefix}.{i}.{down_proj_name}",
weights=weights,
bias=False,
)
for i in range(self.n_experts)
]
self.process_group = weights.process_group
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gating_output: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
if self.n_expert_group is not None and self.topk_group is not None:
topk_weights, topk_ids = grouped_topk(
x,
gating_output,
self.topk,
renormalize=self.renormalize,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
else:
topk_weights, topk_ids = fused_topk(
x, gating_output, self.topk, self.renormalize
)
topk_weights = topk_weights.to(x.dtype)
weights = torch.zeros(
topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
)
weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))
out = torch.zeros_like(x)
for i in range(self.n_experts):
h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
h = self.down_proj[i](h, reduce=False)
out += h * weights[:, i].view(-1, 1)
return out
class SparseMoELayer(nn.Module):
"""

View File

@ -334,6 +334,7 @@ def get_model(
model_type = config_dict.get("model_type", None)
quantization_config = config_dict.get("quantization_config", None)
compression_config = config_dict.get("compression_config", None)
if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None)
if method in {"gptq", "awq", "exl2"}:
@ -344,6 +345,23 @@ def get_model(
quantize = "fp8"
else:
log_master(logger.warning, f"Unknown quantization method {method}")
elif compression_config is not None:
# TODO: at some point we should probably fully parse the compression
# configuration to know which parameters are compressed.
config_groups = compression_config.get("config_groups")
if config_groups is not None:
for _, group in config_groups.items():
weights_config = group.get("weights")
if weights_config is not None:
if (
weights_config["type"] == "float"
and weights_config["num_bits"] == 8
):
log_master(
logger.info, "Auto selecting quantization method fp8"
)
quantize = "fp8"
break
if dtype is None:
if quantize in ["awq", "exl2", "gptq", "marlin"]:
@ -768,7 +786,6 @@ def get_model(
)
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
print(f">>> model_type: {model_type}")
if FLASH_ATTENTION:
return FlashCausalLM(
model_id=model_id,

View File

@ -13,18 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Type
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "rocm":
from text_generation_server.layers import grouped_topk
elif SYSTEM != "ipex":
from moe_kernels.fused_moe import grouped_topk
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from text_generation_server.layers import (
FastLinear,
SpeculativeHead,
@ -34,18 +33,15 @@ from text_generation_server.layers import (
get_linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import SparseMoELayer
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
if SYSTEM == "rocm":
try:
@ -415,8 +411,14 @@ class DeepseekV2MLP(nn.Module):
)
class BlockSparseMoE(nn.Module):
def __init__(self, prefix, config: DeepseekV2Config, weights):
class DeepseekV2MoE(nn.Module):
def __init__(
self,
prefix,
config: DeepseekV2Config,
moe_layer_cls: Type[MoELayer],
weights,
):
super().__init__()
self.hidden_dim = config.hidden_size
@ -428,7 +430,7 @@ class BlockSparseMoE(nn.Module):
# Gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
self.moe_layer = SparseMoELayer(
self.moe_layer = moe_layer_cls(
prefix=f"{prefix}.experts",
n_experts=config.n_routed_experts,
n_expert_group=config.n_group,
@ -437,6 +439,7 @@ class BlockSparseMoE(nn.Module):
topk_group=config.topk_group,
weights=weights,
)
assert isinstance(self.moe_layer, MoELayer)
if config.n_shared_experts is not None:
self.shared_experts = DeepseekV2MLP(
@ -471,96 +474,6 @@ class BlockSparseMoE(nn.Module):
return out.view(*x.shape)
class DenseMoE(nn.Module):
def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = config.moe_intermediate_size
self.n_routed_experts = config.n_routed_experts
self.n_expert_group = config.n_group
self.topk_group = config.topk_group
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.routed_scaling_factor = config.routed_scaling_factor
# Gating
#
# Seems like no one quantizes the gate.
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
self.experts = [
DeepseekV2MLP(
f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size
)
for i in range(self.n_routed_experts)
]
if config.n_shared_experts is not None:
self.shared_experts = DeepseekV2MLP(
prefix=f"{prefix}.shared_experts",
config=config,
weights=weights,
intermediate_size=config.moe_intermediate_size
* config.n_shared_experts,
)
else:
self.shared_experts = None
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
if self.shared_experts is not None:
shared_output = self.shared_experts(x, reduce=False)
else:
shared_output = None
# gate_logits: (sequence_length, n_experts)
router_logits = self.gate(x)
topk_weights, topk_ids = grouped_topk(
x,
router_logits,
self.top_k,
renormalize=self.norm_topk_prob,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
)
out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor
if shared_output is not None:
out = out + shared_output
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out
def moe_infer_gpu(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
):
weights = torch.zeros(
topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device
)
weights.scatter_(1, topk_ids, topk_weight)
out = x.new_zeros(x.shape[0], self.hidden_dim)
for i, expert in enumerate(self.experts):
# Add expert output to out with masking
out += expert(x, reduce=False) * weights[:, i].view(-1, 1)
return out
class DeepseekV2Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
super().__init__()
@ -577,10 +490,12 @@ class DeepseekV2Layer(nn.Module):
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
moe_cls = (
BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
moe_layer_cls = (
SparseMoELayer
if SparseMoELayer.is_supported(weights)
else DenseMoELayer
)
self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
self.mlp = DeepseekV2MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
else:
self.mlp = DeepseekV2MLP(
prefix=f"{prefix}.mlp",

View File

@ -38,6 +38,8 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
@ -161,7 +163,9 @@ def _load_gqa(config, prefix: str, weights):
class FlashGemma2Attention(torch.nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
@ -192,14 +196,32 @@ class FlashGemma2Attention(torch.nn.Module):
)
self.softcap = config.attn_logit_softcapping
self.query_key_value = load_attention(config, prefix, weights)
query_key_value = load_attention(config, prefix, weights)
self.query_key_value = TensorParallelMultiAdapterLinear.load(
query_key_value,
layer_id,
["q_proj", "k_proj", "v_proj"],
sizes=[
self.head_size * config.num_attention_heads,
self.head_size * config.num_key_value_heads,
self.head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
self.o_proj = TensorParallelRowLinear.load(
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
layer_id,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -216,8 +238,9 @@ class FlashGemma2Attention(torch.nn.Module):
slots,
seqlen,
max_s,
adapter_data,
):
qkv = self.query_key_value(hidden_states)
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
@ -260,11 +283,13 @@ class FlashGemma2Attention(torch.nn.Module):
softcap=self.softcap,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class Gemma2MLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, layer_id):
super().__init__()
act = config.hidden_activation
self.act = (
@ -278,40 +303,65 @@ class Gemma2MLP(nn.Module):
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
layer_id,
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
layer_id,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class FlashGemma2Layer(nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
):
super().__init__()
self.self_attn = FlashGemma2Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
causal=causal,
is_sliding=is_sliding,
)
self.mlp = Gemma2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.mlp = Gemma2MLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
)
self.input_layernorm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@ -344,6 +394,7 @@ class FlashGemma2Layer(nn.Module):
slots,
seqlen,
max_s,
adapter_data,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -358,6 +409,7 @@ class FlashGemma2Layer(nn.Module):
slots,
seqlen,
max_s,
adapter_data,
)
# faster post attention rms norm
@ -366,7 +418,7 @@ class FlashGemma2Layer(nn.Module):
res = normed_attn_res_output
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
mlp_output = self.mlp(pre_normed)
mlp_output = self.mlp(pre_normed, adapter_data)
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
return post_hidden_states, normed_attn_res_output
@ -385,6 +437,7 @@ class FlashGemma2Model(torch.nn.Module):
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
layer_id=layer_id,
causal=causal,
is_sliding=layer_id % 2 == 0,
)
@ -409,6 +462,7 @@ class FlashGemma2Model(torch.nn.Module):
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -431,6 +485,7 @@ class FlashGemma2Model(torch.nn.Module):
slots,
seqlen,
max_s,
adapter_data,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -492,6 +547,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
slots,
seqlen,
max_s,
adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]

View File

@ -19,37 +19,30 @@
# limitations under the License.
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from typing import List, Optional, Tuple, Type
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
Seqlen,
)
from text_generation_server.layers import (
FastLinear,
TensorParallelRowLinear,
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.moe import SparseMoELayer
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
reshape_and_cache,
)
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.weights import UnquantizedWeight
@ -315,14 +308,16 @@ def round_up(x: torch.Tensor, value: int):
return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
class BlockSparseMoE(nn.Module):
def __init__(self, prefix, config: MixtralConfig, weights):
class MixtralMoE(nn.Module):
def __init__(
self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights
):
super().__init__()
# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
self.moe = SparseMoELayer(
self.moe = moe_layer_cls(
n_expert_group=None,
n_experts=config.num_local_experts,
prefix=f"{prefix}.experts",
@ -334,6 +329,7 @@ class BlockSparseMoE(nn.Module):
up_proj_name="w3",
down_proj_name="w2",
)
assert isinstance(self.moe, MoELayer)
self.process_group = weights.process_group
@ -349,95 +345,6 @@ class BlockSparseMoE(nn.Module):
return out.view(*x.shape)
class DenseMoE(nn.Module):
def __init__(self, prefix, config: MixtralConfig, weights):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size // weights.process_group.size()
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
act = config.hidden_act
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
elif "silu" in act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[act]
# gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
self.w1 = [
TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
)
for i in range(self.num_experts)
]
self.w3 = [
TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
)
for i in range(self.num_experts)
]
self.w2 = [
TensorParallelRowLinear.load(
config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
)
for i in range(self.num_experts)
]
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
all_probs,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
)
# Mask not selected experts
all_probs.scatter_(1, not_selected_experts, 0)
# Re-normalize
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
weights = weights.to(x.dtype)
# Final output tensor
out = x.new_zeros(x.shape[0], self.hidden_dim)
for i in range(self.num_experts):
h = self.act(self.w1[i](x)) * self.w3[i](x)
h = self.w2[i](h, reduce=False)
# Add expert output to out with masking
out += h * weights[:, i].view(-1, 1)
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out
class MixtralLayer(nn.Module):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
@ -447,8 +354,12 @@ class MixtralLayer(nn.Module):
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
moe_cls = BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
moe_layer_cls = (
SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
)
self.moe = MixtralMoE(
f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps

View File

@ -75,7 +75,6 @@ def load_and_merge_adapters(
weight_names: Tuple[str],
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
if len(adapter_parameters.adapter_info) == 1:
adapter = next(iter(adapter_parameters.adapter_info))
return load_module_map(
@ -191,16 +190,15 @@ def load_module_map(
weight_names: Tuple[str],
trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
if not adapter_path and adapter_config.base_model_name_or_path != model_id:
check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
adapter_filenames = (
hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
hub._weight_files_from_dir(adapter_path, extension=".safetensors")
if adapter_path
else hub._cached_adapter_weight_files(
else hub._cached_weight_files(
adapter_id, revision=revision, extension=".safetensors"
)
)

View File

@ -18,17 +18,6 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
def _cached_adapter_weight_files(
adapter_id: str, revision: Optional[str], extension: str
) -> List[str]:
"""Guess weight files from the cached revision snapshot directory"""
d = _get_cached_revision_directory(adapter_id, revision)
if not d:
return []
filenames = _adapter_weight_files_from_dir(d, extension)
return filenames
def _cached_weight_files(
model_id: str, revision: Optional[str], extension: str
) -> List[str]:
@ -65,39 +54,11 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "adapter" not in f
and "training" not in f
]
return filenames
def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root, _, files = next(os.walk(str(d)))
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "training" not in f
]
return filenames
def _adapter_config_files_from_dir(d: Path) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root, _, files = next(os.walk(str(d)))
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(".json") and "arguments" not in f and "args" not in f
]
return filenames
def _get_cached_revision_directory(
model_id: str, revision: Optional[str]
) -> Optional[Path]: