Merge remote-tracking branch 'upstream/main' into fix_rocm_fa
This commit is contained in:
commit
47c81d2924
|
@ -58,3 +58,6 @@ jobs:
|
||||||
- name: Run Rust tests
|
- name: Run Rust tests
|
||||||
run: |
|
run: |
|
||||||
cargo test
|
cargo test
|
||||||
|
- name: Run Rust tests with google feature
|
||||||
|
run: |
|
||||||
|
cargo test --features google
|
||||||
|
|
|
@ -3,9 +3,8 @@ target
|
||||||
router/tokenizer.json
|
router/tokenizer.json
|
||||||
*__pycache__*
|
*__pycache__*
|
||||||
|
|
||||||
|
backends/v2/src/client/pb
|
||||||
backends/v3/src/client/pb
|
backends/v3/src/client/pb
|
||||||
backends/client/src/v2/pb
|
|
||||||
backends/client/src/v3/pb
|
|
||||||
|
|
||||||
# ROCm auto-generated files
|
# ROCm auto-generated files
|
||||||
*.hip
|
*.hip
|
||||||
|
|
|
@ -109,9 +109,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anyhow"
|
name = "anyhow"
|
||||||
version = "1.0.88"
|
version = "1.0.89"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4e1496f8fb1fbf272686b8d37f523dab3e4a7443300055e74cdaa449f3114356"
|
checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "arbitrary"
|
name = "arbitrary"
|
||||||
|
@ -177,9 +177,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-trait"
|
name = "async-trait"
|
||||||
version = "0.1.82"
|
version = "0.1.83"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1"
|
checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
@ -269,9 +269,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aws-lc-sys"
|
name = "aws-lc-sys"
|
||||||
version = "0.21.1"
|
version = "0.21.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "234314bd569802ec87011d653d6815c6d7b9ffb969e9fee5b8b20ef860e8dce9"
|
checksum = "b3ddc4a5b231dd6958b140ff3151b6412b3f4321fab354f399eec8f14b06df62"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bindgen",
|
"bindgen",
|
||||||
"cc",
|
"cc",
|
||||||
|
@ -309,19 +309,19 @@ dependencies = [
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
"sync_wrapper 0.1.2",
|
"sync_wrapper 0.1.2",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower",
|
"tower 0.4.13",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum"
|
name = "axum"
|
||||||
version = "0.7.5"
|
version = "0.7.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
|
checksum = "8f43644eed690f5374f1af436ecd6aea01cd201f6fbdf0178adaf6907afb2cec"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum-core 0.4.3",
|
"axum-core 0.4.4",
|
||||||
"bytes",
|
"bytes",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"http 1.1.0",
|
"http 1.1.0",
|
||||||
|
@ -342,7 +342,7 @@ dependencies = [
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
"sync_wrapper 1.0.1",
|
"sync_wrapper 1.0.1",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower",
|
"tower 0.5.1",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
@ -367,9 +367,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum-core"
|
name = "axum-core"
|
||||||
version = "0.4.3"
|
version = "0.4.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3"
|
checksum = "5e6b8ba012a258d63c9adfa28b9ddcf66149da6f986c5b5452e629d5ee64bf00"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"bytes",
|
"bytes",
|
||||||
|
@ -380,7 +380,7 @@ dependencies = [
|
||||||
"mime",
|
"mime",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"rustversion",
|
"rustversion",
|
||||||
"sync_wrapper 0.1.2",
|
"sync_wrapper 1.0.1",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
@ -392,13 +392,13 @@ version = "0.16.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08"
|
checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"axum 0.7.5",
|
"axum 0.7.6",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"http 1.1.0",
|
"http 1.1.0",
|
||||||
"opentelemetry 0.21.0",
|
"opentelemetry 0.21.0",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tower",
|
"tower 0.4.13",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-opentelemetry 0.22.0",
|
"tracing-opentelemetry 0.22.0",
|
||||||
"tracing-opentelemetry-instrumentation-sdk",
|
"tracing-opentelemetry-instrumentation-sdk",
|
||||||
|
@ -546,9 +546,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bytes"
|
name = "bytes"
|
||||||
version = "1.7.1"
|
version = "1.7.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50"
|
checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "camino"
|
name = "camino"
|
||||||
|
@ -605,9 +605,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.1.18"
|
version = "1.1.21"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476"
|
checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"jobserver",
|
"jobserver",
|
||||||
"libc",
|
"libc",
|
||||||
|
@ -675,9 +675,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clap"
|
name = "clap"
|
||||||
version = "4.5.17"
|
version = "4.5.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac"
|
checksum = "b0956a43b323ac1afaffc053ed5c4b7c1f1800bacd1683c353aabbb752515dd3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"clap_builder",
|
"clap_builder",
|
||||||
"clap_derive",
|
"clap_derive",
|
||||||
|
@ -685,9 +685,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clap_builder"
|
name = "clap_builder"
|
||||||
version = "4.5.17"
|
version = "4.5.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73"
|
checksum = "4d72166dd41634086d5803a47eb71ae740e61d84709c36f3c34110173db3961b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anstream",
|
"anstream",
|
||||||
"anstyle",
|
"anstyle",
|
||||||
|
@ -697,9 +697,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "clap_derive"
|
name = "clap_derive"
|
||||||
version = "4.5.13"
|
version = "4.5.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0"
|
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"heck 0.5.0",
|
"heck 0.5.0",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
|
@ -1732,9 +1732,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hyper-util"
|
name = "hyper-util"
|
||||||
version = "0.1.8"
|
version = "0.1.9"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba"
|
checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytes",
|
"bytes",
|
||||||
"futures-channel",
|
"futures-channel",
|
||||||
|
@ -1745,7 +1745,6 @@ dependencies = [
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"socket2",
|
"socket2",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower",
|
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
@ -1985,7 +1984,7 @@ dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"base64 0.21.7",
|
"base64 0.21.7",
|
||||||
"bytecount",
|
"bytecount",
|
||||||
"clap 4.5.17",
|
"clap 4.5.18",
|
||||||
"fancy-regex",
|
"fancy-regex",
|
||||||
"fraction",
|
"fraction",
|
||||||
"getrandom",
|
"getrandom",
|
||||||
|
@ -2025,9 +2024,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libc"
|
name = "libc"
|
||||||
version = "0.2.158"
|
version = "0.2.159"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
|
checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "libfuzzer-sys"
|
name = "libfuzzer-sys"
|
||||||
|
@ -2240,9 +2239,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "minijinja"
|
name = "minijinja"
|
||||||
version = "2.2.0"
|
version = "2.3.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6d7d3e3a3eece1fa4618237ad41e1de855ced47eab705cec1c9a920e1d1c5aad"
|
checksum = "1028b628753a7e1a88fc59c9ba4b02ecc3bc0bd3c7af23df667bc28df9b3310e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
@ -2250,9 +2249,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "minijinja-contrib"
|
name = "minijinja-contrib"
|
||||||
version = "2.2.0"
|
version = "2.3.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "744a2b84dbd22398e347594ed2aef9d3f1b948934e3e6e94ef69ecd39d597f4b"
|
checksum = "39ffd46ee854be23604a20efd6c9655374fefbe4d44b949dc0f907305d92873a"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"minijinja",
|
"minijinja",
|
||||||
"serde",
|
"serde",
|
||||||
|
@ -2600,9 +2599,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "once_cell"
|
name = "once_cell"
|
||||||
version = "1.20.0"
|
version = "1.19.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "33ea5043e58958ee56f3e15a90aee535795cd7dfd319846288d93c5b57d85cbe"
|
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "onig"
|
name = "onig"
|
||||||
|
@ -2955,9 +2954,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pkg-config"
|
name = "pkg-config"
|
||||||
version = "0.3.30"
|
version = "0.3.31"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec"
|
checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "plotters"
|
name = "plotters"
|
||||||
|
@ -3002,9 +3001,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "portable-atomic"
|
name = "portable-atomic"
|
||||||
version = "1.7.0"
|
version = "1.8.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265"
|
checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "powerfmt"
|
name = "powerfmt"
|
||||||
|
@ -3161,9 +3160,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyo3"
|
name = "pyo3"
|
||||||
version = "0.22.2"
|
version = "0.22.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433"
|
checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"indoc",
|
"indoc",
|
||||||
|
@ -3179,9 +3178,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyo3-build-config"
|
name = "pyo3-build-config"
|
||||||
version = "0.22.2"
|
version = "0.22.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8"
|
checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"target-lexicon",
|
"target-lexicon",
|
||||||
|
@ -3189,9 +3188,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyo3-ffi"
|
name = "pyo3-ffi"
|
||||||
version = "0.22.2"
|
version = "0.22.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6"
|
checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"pyo3-build-config",
|
"pyo3-build-config",
|
||||||
|
@ -3199,9 +3198,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyo3-macros"
|
name = "pyo3-macros"
|
||||||
version = "0.22.2"
|
version = "0.22.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206"
|
checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"pyo3-macros-backend",
|
"pyo3-macros-backend",
|
||||||
|
@ -3211,9 +3210,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyo3-macros-backend"
|
name = "pyo3-macros-backend"
|
||||||
version = "0.22.2"
|
version = "0.22.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372"
|
checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"heck 0.5.0",
|
"heck 0.5.0",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
|
@ -3403,9 +3402,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.5.4"
|
version = "0.5.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0884ad60e090bf1345b93da0a5de8923c93884cd03f40dfcfddd3b4bee661853"
|
checksum = "62871f2d65009c0256aed1b9cfeeb8ac272833c404e13d53d400cd0dad7a2ac0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags 2.6.0",
|
"bitflags 2.6.0",
|
||||||
]
|
]
|
||||||
|
@ -3770,9 +3769,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "security-framework-sys"
|
name = "security-framework-sys"
|
||||||
version = "2.11.1"
|
version = "2.12.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf"
|
checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"core-foundation-sys",
|
"core-foundation-sys",
|
||||||
"libc",
|
"libc",
|
||||||
|
@ -4175,11 +4174,11 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-backends-trtllm"
|
name = "text-generation-backends-trtllm"
|
||||||
version = "2.2.1-dev0"
|
version = "2.3.1-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"clap 4.5.17",
|
"clap 4.5.18",
|
||||||
"cmake",
|
"cmake",
|
||||||
"cxx",
|
"cxx",
|
||||||
"cxx-build",
|
"cxx-build",
|
||||||
|
@ -4198,11 +4197,10 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-benchmark"
|
name = "text-generation-benchmark"
|
||||||
version = "2.2.1-dev0"
|
version = "2.3.1-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"average",
|
"average",
|
||||||
"clap 4.5.17",
|
"clap 4.5.18",
|
||||||
"crossterm",
|
|
||||||
"float-ord",
|
"float-ord",
|
||||||
"hf-hub",
|
"hf-hub",
|
||||||
"ratatui",
|
"ratatui",
|
||||||
|
@ -4219,7 +4217,7 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-client"
|
name = "text-generation-client"
|
||||||
version = "2.2.1-dev0"
|
version = "2.3.1-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
|
@ -4231,15 +4229,15 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
"tonic 0.10.2",
|
"tonic 0.10.2",
|
||||||
"tonic-build",
|
"tonic-build",
|
||||||
"tower",
|
"tower 0.4.13",
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-launcher"
|
name = "text-generation-launcher"
|
||||||
version = "2.2.1-dev0"
|
version = "2.3.1-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"clap 4.5.17",
|
"clap 4.5.18",
|
||||||
"ctrlc",
|
"ctrlc",
|
||||||
"float_eq",
|
"float_eq",
|
||||||
"hf-hub",
|
"hf-hub",
|
||||||
|
@ -4256,14 +4254,14 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-router"
|
name = "text-generation-router"
|
||||||
version = "2.2.1-dev0"
|
version = "2.3.1-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum 0.7.5",
|
"axum 0.7.6",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"clap 4.5.17",
|
"clap 4.5.18",
|
||||||
"csv",
|
"csv",
|
||||||
"futures",
|
"futures",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
@ -4304,15 +4302,64 @@ dependencies = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "text-generation-router-v3"
|
name = "text-generation-router-v2"
|
||||||
version = "2.2.1-dev0"
|
version = "2.3.1-dev0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum 0.7.5",
|
"axum 0.7.6",
|
||||||
"axum-tracing-opentelemetry",
|
"axum-tracing-opentelemetry",
|
||||||
"base64 0.22.1",
|
"base64 0.22.1",
|
||||||
"clap 4.5.17",
|
"clap 4.5.18",
|
||||||
|
"futures",
|
||||||
|
"futures-util",
|
||||||
|
"grpc-metadata",
|
||||||
|
"hf-hub",
|
||||||
|
"image",
|
||||||
|
"init-tracing-opentelemetry",
|
||||||
|
"jsonschema",
|
||||||
|
"metrics",
|
||||||
|
"metrics-exporter-prometheus",
|
||||||
|
"minijinja",
|
||||||
|
"minijinja-contrib",
|
||||||
|
"nohash-hasher",
|
||||||
|
"once_cell",
|
||||||
|
"opentelemetry 0.20.0",
|
||||||
|
"opentelemetry-otlp",
|
||||||
|
"prost 0.12.6",
|
||||||
|
"prost-build",
|
||||||
|
"rand",
|
||||||
|
"regex",
|
||||||
|
"reqwest",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"slotmap",
|
||||||
|
"text-generation-router",
|
||||||
|
"thiserror",
|
||||||
|
"tokenizers 0.20.0",
|
||||||
|
"tokio",
|
||||||
|
"tokio-stream",
|
||||||
|
"tonic 0.10.2",
|
||||||
|
"tonic-build",
|
||||||
|
"tower 0.4.13",
|
||||||
|
"tower-http",
|
||||||
|
"tracing",
|
||||||
|
"tracing-opentelemetry 0.21.0",
|
||||||
|
"tracing-subscriber",
|
||||||
|
"utoipa",
|
||||||
|
"utoipa-swagger-ui",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "text-generation-router-v3"
|
||||||
|
version = "2.3.1-dev0"
|
||||||
|
dependencies = [
|
||||||
|
"async-stream",
|
||||||
|
"async-trait",
|
||||||
|
"axum 0.7.6",
|
||||||
|
"axum-tracing-opentelemetry",
|
||||||
|
"base64 0.22.1",
|
||||||
|
"clap 4.5.18",
|
||||||
"criterion",
|
"criterion",
|
||||||
"futures",
|
"futures",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
@ -4345,7 +4392,7 @@ dependencies = [
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tonic 0.10.2",
|
"tonic 0.10.2",
|
||||||
"tonic-build",
|
"tonic-build",
|
||||||
"tower",
|
"tower 0.4.13",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-opentelemetry 0.21.0",
|
"tracing-opentelemetry 0.21.0",
|
||||||
|
@ -4365,18 +4412,18 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "1.0.63"
|
version = "1.0.64"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
|
checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"thiserror-impl",
|
"thiserror-impl",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror-impl"
|
name = "thiserror-impl"
|
||||||
version = "1.0.63"
|
version = "1.0.64"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
|
checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
@ -4647,9 +4694,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "toml_edit"
|
name = "toml_edit"
|
||||||
version = "0.22.20"
|
version = "0.22.22"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d"
|
checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"indexmap 2.5.0",
|
"indexmap 2.5.0",
|
||||||
"serde",
|
"serde",
|
||||||
|
@ -4680,7 +4727,7 @@ dependencies = [
|
||||||
"prost 0.11.9",
|
"prost 0.11.9",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tower",
|
"tower 0.4.13",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
@ -4707,7 +4754,7 @@ dependencies = [
|
||||||
"prost 0.12.6",
|
"prost 0.12.6",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tower",
|
"tower 0.4.13",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
@ -4746,6 +4793,22 @@ dependencies = [
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tower"
|
||||||
|
version = "0.5.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f"
|
||||||
|
dependencies = [
|
||||||
|
"futures-core",
|
||||||
|
"futures-util",
|
||||||
|
"pin-project-lite",
|
||||||
|
"sync_wrapper 0.1.2",
|
||||||
|
"tokio",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tower-http"
|
name = "tower-http"
|
||||||
version = "0.5.2"
|
version = "0.5.2"
|
||||||
|
@ -4959,9 +5022,9 @@ checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-normalization"
|
name = "unicode-normalization"
|
||||||
version = "0.1.23"
|
version = "0.1.24"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5"
|
checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"tinyvec",
|
"tinyvec",
|
||||||
]
|
]
|
||||||
|
@ -4994,9 +5057,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-width"
|
name = "unicode-width"
|
||||||
version = "0.1.13"
|
version = "0.1.14"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d"
|
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode_categories"
|
name = "unicode_categories"
|
||||||
|
@ -5096,7 +5159,7 @@ version = "6.0.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da"
|
checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"axum 0.7.5",
|
"axum 0.7.6",
|
||||||
"mime_guess",
|
"mime_guess",
|
||||||
"regex",
|
"regex",
|
||||||
"rust-embed",
|
"rust-embed",
|
||||||
|
@ -5313,9 +5376,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "webpki-roots"
|
name = "webpki-roots"
|
||||||
version = "0.26.5"
|
version = "0.26.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0bd24728e5af82c6c4ec1b66ac4844bdf8156257fccda846ec58b42cd0cdbe6a"
|
checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
]
|
]
|
||||||
|
@ -5604,9 +5667,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "winnow"
|
name = "winnow"
|
||||||
version = "0.6.18"
|
version = "0.6.19"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f"
|
checksum = "c52ac009d615e79296318c1bcce2d422aaca15ad08515e344feeda07df67a587"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"memchr",
|
"memchr",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,26 +1,26 @@
|
||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
"benchmark",
|
"benchmark",
|
||||||
|
"backends/v2",
|
||||||
"backends/v3",
|
"backends/v3",
|
||||||
"backends/grpc-metadata",
|
"backends/grpc-metadata",
|
||||||
"backends/trtllm",
|
"backends/trtllm",
|
||||||
"backends/client",
|
|
||||||
"launcher",
|
"launcher",
|
||||||
"router"
|
"router"
|
||||||
]
|
]
|
||||||
default-members = [
|
default-members = [
|
||||||
"benchmark",
|
"benchmark",
|
||||||
|
"backends/v2",
|
||||||
"backends/v3",
|
"backends/v3",
|
||||||
"backends/grpc-metadata",
|
"backends/grpc-metadata",
|
||||||
# "backends/trtllm",
|
# "backends/trtllm",
|
||||||
"backends/client",
|
|
||||||
"launcher",
|
"launcher",
|
||||||
"router"
|
"router"
|
||||||
]
|
]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "2.2.1-dev0"
|
version = "2.3.1-dev0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["Olivier Dehaene"]
|
authors = ["Olivier Dehaene"]
|
||||||
homepage = "https://github.com/huggingface/text-generation-inference"
|
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||||
|
|
|
@ -40,7 +40,6 @@ COPY router router
|
||||||
COPY backends backends
|
COPY backends backends
|
||||||
COPY launcher launcher
|
COPY launcher launcher
|
||||||
RUN cargo build --profile release-opt
|
RUN cargo build --profile release-opt
|
||||||
RUN cargo build --profile release-opt
|
|
||||||
|
|
||||||
# Python builder
|
# Python builder
|
||||||
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
||||||
|
|
|
@ -83,7 +83,7 @@ model=HuggingFaceH4/zephyr-7b-beta
|
||||||
volume=$PWD/data
|
volume=$PWD/data
|
||||||
|
|
||||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:2.2.0 --model-id $model
|
ghcr.io/huggingface/text-generation-inference:2.3.0 --model-id $model
|
||||||
```
|
```
|
||||||
|
|
||||||
And then you can make requests like
|
And then you can make requests like
|
||||||
|
|
17
_server.nix
17
_server.nix
|
@ -1,17 +0,0 @@
|
||||||
{
|
|
||||||
mkPoetryApplication,
|
|
||||||
pkg-config,
|
|
||||||
protobuf,
|
|
||||||
openssl,
|
|
||||||
}:
|
|
||||||
|
|
||||||
mkPoetryApplication {
|
|
||||||
# name = "text-generation-server";
|
|
||||||
|
|
||||||
projectDir = ./server;
|
|
||||||
|
|
||||||
# nativeBuildInputs = [ pkg-config ];
|
|
||||||
|
|
||||||
# buildInputs = [ openssl.dev protobuf ];
|
|
||||||
|
|
||||||
}
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
[package]
|
||||||
|
name = "text-generation-router-v2"
|
||||||
|
description = "Text Generation Webserver"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
authors.workspace = true
|
||||||
|
homepage.workspace = true
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/lib.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "text-generation-router-v2"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
async-trait = "0.1.74"
|
||||||
|
async-stream = "0.3.5"
|
||||||
|
axum = { version = "0.7", features = ["json"] }
|
||||||
|
axum-tracing-opentelemetry = "0.16"
|
||||||
|
text-generation-router = { path = "../../router" }
|
||||||
|
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||||
|
grpc-metadata = { path = "../grpc-metadata" }
|
||||||
|
futures = "0.3.28"
|
||||||
|
hf-hub = { workspace = true }
|
||||||
|
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||||
|
metrics = { workspace = true }
|
||||||
|
metrics-exporter-prometheus = { workspace = true }
|
||||||
|
nohash-hasher = "0.2.0"
|
||||||
|
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||||
|
opentelemetry-otlp = "0.13.0"
|
||||||
|
rand = "0.8.5"
|
||||||
|
reqwest = { version = "0.11.20", features = [] }
|
||||||
|
serde = "1.0.188"
|
||||||
|
serde_json = "1.0.107"
|
||||||
|
slotmap = "1.0.7"
|
||||||
|
thiserror = "1.0.48"
|
||||||
|
tokenizers = { workspace = true }
|
||||||
|
tokio = { version = "1.32.0", features = [
|
||||||
|
"rt",
|
||||||
|
"rt-multi-thread",
|
||||||
|
"parking_lot",
|
||||||
|
"signal",
|
||||||
|
"sync",
|
||||||
|
] }
|
||||||
|
tokio-stream = "0.1.14"
|
||||||
|
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||||
|
tracing = "0.1.37"
|
||||||
|
tracing-opentelemetry = "0.21.0"
|
||||||
|
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||||
|
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||||
|
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||||
|
init-tracing-opentelemetry = { version = "0.14.1", features = [
|
||||||
|
"opentelemetry-otlp",
|
||||||
|
] }
|
||||||
|
minijinja = { workspace = true }
|
||||||
|
minijinja-contrib = { workspace = true }
|
||||||
|
futures-util = "0.3.30"
|
||||||
|
regex = "1.10.3"
|
||||||
|
once_cell = "1.19.0"
|
||||||
|
image = "0.25.1"
|
||||||
|
base64 = { workspace = true }
|
||||||
|
prost = "^0.12"
|
||||||
|
tonic = "^0.10"
|
||||||
|
tower = "^0.4"
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
tonic-build = "0.10.1"
|
||||||
|
prost-build = "0.12.1"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["ngrok"]
|
||||||
|
ngrok = ["text-generation-router/ngrok"]
|
||||||
|
google = ["text-generation-router/google"]
|
||||||
|
kserve = ["text-generation-router/kserve"]
|
|
@ -0,0 +1,19 @@
|
||||||
|
use std::fs;
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
println!("cargo:rerun-if-changed=../../proto/");
|
||||||
|
|
||||||
|
fs::create_dir_all("src/client/pb").unwrap_or(());
|
||||||
|
let mut config = prost_build::Config::new();
|
||||||
|
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||||
|
|
||||||
|
tonic_build::configure()
|
||||||
|
.build_client(true)
|
||||||
|
.build_server(false)
|
||||||
|
.out_dir("src/client/pb")
|
||||||
|
.include_file("mod.rs")
|
||||||
|
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
|
||||||
|
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -0,0 +1,506 @@
|
||||||
|
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
||||||
|
/// Batching and inference logic
|
||||||
|
use crate::queue::{Entry, Queue};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use nohash_hasher::IntMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
|
use text_generation_router::validation::ValidGenerateRequest;
|
||||||
|
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
|
||||||
|
use tokio::sync::mpsc::error::SendError;
|
||||||
|
use tokio::sync::{mpsc, Notify};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
|
use tracing::{info_span, instrument, Instrument, Span};
|
||||||
|
|
||||||
|
pub struct BackendV2 {
|
||||||
|
/// Request queue
|
||||||
|
queue: Queue,
|
||||||
|
/// Notify batcher on queue appends
|
||||||
|
batching_task_notifier: Arc<Notify>,
|
||||||
|
/// Client clone, used for health checks to skip the queue
|
||||||
|
client: ShardedClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BackendV2 {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) fn new(
|
||||||
|
client: ShardedClient,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
max_batch_total_tokens: u32,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
requires_padding: bool,
|
||||||
|
window_size: Option<u32>,
|
||||||
|
speculate: u32,
|
||||||
|
) -> Self {
|
||||||
|
// Infer shared state
|
||||||
|
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
|
||||||
|
attention
|
||||||
|
.parse()
|
||||||
|
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
|
||||||
|
} else {
|
||||||
|
Attention::Paged
|
||||||
|
};
|
||||||
|
let block_size = if attention == Attention::FlashDecoding {
|
||||||
|
256
|
||||||
|
} else {
|
||||||
|
16
|
||||||
|
};
|
||||||
|
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
|
||||||
|
let batching_task_notifier = Arc::new(Notify::new());
|
||||||
|
|
||||||
|
// Spawn batching background task that contains all the inference logic
|
||||||
|
tokio::spawn(batching_task(
|
||||||
|
client.clone(),
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
queue.clone(),
|
||||||
|
batching_task_notifier.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
Self {
|
||||||
|
queue,
|
||||||
|
batching_task_notifier,
|
||||||
|
client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Backend for BackendV2 {
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
fn schedule(
|
||||||
|
&self,
|
||||||
|
request: ValidGenerateRequest,
|
||||||
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||||
|
// MPSC channel to communicate with the background batching task
|
||||||
|
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
// Append the request to the queue
|
||||||
|
self.queue.append(Entry {
|
||||||
|
request,
|
||||||
|
response_tx,
|
||||||
|
span: Span::current(),
|
||||||
|
temp_span: None,
|
||||||
|
queue_time: Instant::now(),
|
||||||
|
batch_time: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Notify the background task that we have a new entry in the queue that needs
|
||||||
|
// to be batched
|
||||||
|
self.batching_task_notifier.notify_one();
|
||||||
|
|
||||||
|
// Return stream
|
||||||
|
Ok(UnboundedReceiverStream::new(response_rx))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health(&self, current_health: bool) -> bool {
|
||||||
|
if current_health {
|
||||||
|
// Generation is healthy, we only check that the shards can allocate on device
|
||||||
|
self.client.device_health().await
|
||||||
|
} else {
|
||||||
|
self.client.model_health().await
|
||||||
|
}
|
||||||
|
.is_ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Batching logic
|
||||||
|
/// Will be launched in a background Tokio task
|
||||||
|
///
|
||||||
|
/// Batches requests and sends them to the inference server
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) async fn batching_task(
|
||||||
|
mut client: ShardedClient,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
max_batch_total_tokens: u32,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
queue: Queue,
|
||||||
|
notifier: Arc<Notify>,
|
||||||
|
) {
|
||||||
|
// Infinite loop
|
||||||
|
loop {
|
||||||
|
// Wait for a notification from the Infer struct
|
||||||
|
notifier.notified().await;
|
||||||
|
|
||||||
|
// Get the next batch from the queue
|
||||||
|
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||||
|
// waiting in the queue
|
||||||
|
while let Some((mut entries, batch, span)) = queue
|
||||||
|
.next_batch(
|
||||||
|
None,
|
||||||
|
max_batch_size,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||||
|
.instrument(span)
|
||||||
|
.await;
|
||||||
|
let mut waiting_tokens = 1;
|
||||||
|
|
||||||
|
// We loop until we do not receive any cached batch from the inference server (== until
|
||||||
|
// all requests have met their stopping criteria)
|
||||||
|
while let Some(batch) = cached_batch {
|
||||||
|
// Get current batch info
|
||||||
|
let batch_size = batch.size;
|
||||||
|
let batch_max_tokens = batch.max_tokens;
|
||||||
|
let mut batches = vec![batch];
|
||||||
|
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||||
|
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||||
|
|
||||||
|
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||||
|
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||||
|
// to add a new batch even though its size might be small
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
// Minimum batch size
|
||||||
|
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||||
|
};
|
||||||
|
|
||||||
|
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||||
|
let max_size =
|
||||||
|
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
|
||||||
|
// Try to get a new batch
|
||||||
|
if let Some((mut new_entries, new_batch, span)) = queue
|
||||||
|
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
// Tracking metrics
|
||||||
|
if min_size.is_some() {
|
||||||
|
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||||
|
.increment(1);
|
||||||
|
} else {
|
||||||
|
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||||
|
.increment(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
entries.iter_mut().for_each(|(_, entry)| {
|
||||||
|
// Create a new span to add the info that this entry is waiting
|
||||||
|
// because a new batch is being computed
|
||||||
|
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||||
|
// Add relationships
|
||||||
|
span.follows_from(&entry_waiting_span);
|
||||||
|
entry_waiting_span.follows_from(&span);
|
||||||
|
// Update entry
|
||||||
|
entry.temp_span = Some(entry_waiting_span);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Generate one token for this new batch to have the attention past in cache
|
||||||
|
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||||
|
.instrument(span)
|
||||||
|
.await;
|
||||||
|
// Reset waiting counter
|
||||||
|
waiting_tokens = 1;
|
||||||
|
// Extend current batch with the new batch
|
||||||
|
if let Some(new_cached_batch) = new_cached_batch {
|
||||||
|
entries.extend(new_entries);
|
||||||
|
batches.push(new_cached_batch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create span for this batch to add context to inference calls
|
||||||
|
let next_batch_size = entries.len();
|
||||||
|
let next_batch_span =
|
||||||
|
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||||
|
entries.iter_mut().for_each(|(_, entry)| {
|
||||||
|
// Create a new span to link the batch back to this entry
|
||||||
|
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||||
|
// Add relationships
|
||||||
|
next_batch_span.follows_from(&entry_batch_span);
|
||||||
|
entry_batch_span.follows_from(&next_batch_span);
|
||||||
|
// Update entry
|
||||||
|
entry.temp_span = Some(entry_batch_span);
|
||||||
|
});
|
||||||
|
|
||||||
|
cached_batch = decode(&mut client, batches, &mut entries)
|
||||||
|
.instrument(next_batch_span)
|
||||||
|
.await;
|
||||||
|
waiting_tokens += 1;
|
||||||
|
}
|
||||||
|
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
||||||
|
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn prefill(
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
batch: Batch,
|
||||||
|
entries: &mut IntMap<u64, Entry>,
|
||||||
|
) -> Option<CachedBatch> {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let batch_id = batch.id;
|
||||||
|
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||||
|
|
||||||
|
match client.prefill(batch).await {
|
||||||
|
Ok((generations, next_batch, timings)) => {
|
||||||
|
let start_filtering_time = Instant::now();
|
||||||
|
// Send generated tokens and filter stopped entries
|
||||||
|
filter_send_generations(generations, entries);
|
||||||
|
|
||||||
|
// Filter next batch and remove requests that were stopped
|
||||||
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||||
|
|
||||||
|
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
|
||||||
|
.record(timings.forward.as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||||
|
.record(timings.decode.as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||||
|
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
|
||||||
|
.record(start_time.elapsed().as_secs_f64());
|
||||||
|
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||||
|
next_batch
|
||||||
|
}
|
||||||
|
// If we have an error, we discard the whole batch
|
||||||
|
Err(err) => {
|
||||||
|
let _ = client.clear_cache(Some(batch_id)).await;
|
||||||
|
send_errors(err, entries);
|
||||||
|
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn decode(
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
batches: Vec<CachedBatch>,
|
||||||
|
entries: &mut IntMap<u64, Entry>,
|
||||||
|
) -> Option<CachedBatch> {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||||
|
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
||||||
|
|
||||||
|
match client.decode(batches).await {
|
||||||
|
Ok((generations, next_batch, timings)) => {
|
||||||
|
let start_filtering_time = Instant::now();
|
||||||
|
// Send generated tokens and filter stopped entries
|
||||||
|
filter_send_generations(generations, entries);
|
||||||
|
|
||||||
|
// Filter next batch and remove requests that were stopped
|
||||||
|
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||||
|
|
||||||
|
if let Some(concat_duration) = timings.concat {
|
||||||
|
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||||
|
.record(concat_duration.as_secs_f64());
|
||||||
|
}
|
||||||
|
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
||||||
|
.record(timings.forward.as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
||||||
|
.record(timings.decode.as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
||||||
|
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||||
|
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
||||||
|
.record(start_time.elapsed().as_secs_f64());
|
||||||
|
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
||||||
|
next_batch
|
||||||
|
}
|
||||||
|
// If we have an error, we discard the whole batch
|
||||||
|
Err(err) => {
|
||||||
|
for id in batch_ids {
|
||||||
|
let _ = client.clear_cache(Some(id)).await;
|
||||||
|
}
|
||||||
|
send_errors(err, entries);
|
||||||
|
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a `batch` and remove all requests not present in `entries`
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn filter_batch(
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
next_batch: Option<CachedBatch>,
|
||||||
|
entries: &IntMap<u64, Entry>,
|
||||||
|
) -> Option<CachedBatch> {
|
||||||
|
let mut batch = next_batch?;
|
||||||
|
|
||||||
|
// No need to filter
|
||||||
|
if batch.size as usize == entries.len() {
|
||||||
|
return Some(batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
let id = batch.id;
|
||||||
|
|
||||||
|
// Retain only requests that are still in entries
|
||||||
|
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||||
|
|
||||||
|
if batch.request_ids.is_empty() {
|
||||||
|
// All requests have been filtered out
|
||||||
|
// Next batch is now empty
|
||||||
|
// Clear it from the Python shards cache
|
||||||
|
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||||
|
client.clear_cache(Some(id)).await.unwrap();
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
// Filter Python shard cache
|
||||||
|
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||||
|
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||||
|
/// and filter entries
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||||
|
generations.into_iter().for_each(|generation| {
|
||||||
|
let id = generation.request_id;
|
||||||
|
// Get entry
|
||||||
|
// We can `expect` here as the request id should always be in the entries
|
||||||
|
let entry = entries
|
||||||
|
.get(&id)
|
||||||
|
.expect("ID not found in entries. This is a bug.");
|
||||||
|
|
||||||
|
// Create and enter a span to link this function back to the entry
|
||||||
|
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||||
|
// Send generation responses back to the infer task
|
||||||
|
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||||
|
// request and we need to stop generating hence why we unwrap_or(true)
|
||||||
|
let stopped = send_responses(generation, entry).inspect_err(|_err| {
|
||||||
|
tracing::error!("Entry response channel error.");
|
||||||
|
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||||
|
}).unwrap_or(true);
|
||||||
|
if stopped {
|
||||||
|
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send responses through the `entry` response channel
|
||||||
|
fn send_responses(
|
||||||
|
generation: Generation,
|
||||||
|
entry: &Entry,
|
||||||
|
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||||
|
// Return directly if the channel is disconnected
|
||||||
|
if entry.response_tx.is_closed() {
|
||||||
|
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut stopped = false;
|
||||||
|
|
||||||
|
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||||
|
// Create Token objects
|
||||||
|
// We do that here instead of in the Python code as Rust for loops are faster
|
||||||
|
let prefill_tokens = prefill_tokens
|
||||||
|
.ids
|
||||||
|
.into_iter()
|
||||||
|
.zip(prefill_tokens.logprobs)
|
||||||
|
.zip(prefill_tokens.texts)
|
||||||
|
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Send message
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create last Token
|
||||||
|
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||||
|
let n = tokens_.ids.len();
|
||||||
|
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||||
|
let mut iterator = tokens_
|
||||||
|
.ids
|
||||||
|
.into_iter()
|
||||||
|
.zip(tokens_.logprobs)
|
||||||
|
.zip(tokens_.texts)
|
||||||
|
.zip(tokens_.is_special)
|
||||||
|
.enumerate()
|
||||||
|
.peekable();
|
||||||
|
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||||
|
let token = Token {
|
||||||
|
id,
|
||||||
|
text,
|
||||||
|
logprob,
|
||||||
|
special,
|
||||||
|
};
|
||||||
|
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||||
|
top_tokens_
|
||||||
|
.ids
|
||||||
|
.iter()
|
||||||
|
.zip(top_tokens_.logprobs.iter())
|
||||||
|
.zip(top_tokens_.texts.iter())
|
||||||
|
.zip(top_tokens_.is_special.iter())
|
||||||
|
.map(|(((&id, &logprob), text), &special)| Token {
|
||||||
|
id,
|
||||||
|
text: text.to_string(),
|
||||||
|
logprob,
|
||||||
|
special,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
match (&generation.generated_text, iterator.peek()) {
|
||||||
|
(Some(generated_text), None) => {
|
||||||
|
// Generation has ended
|
||||||
|
stopped = true;
|
||||||
|
// Send message
|
||||||
|
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||||
|
token,
|
||||||
|
top_tokens,
|
||||||
|
generated_text: GeneratedText::from(generated_text.clone()),
|
||||||
|
queued: entry.queue_time,
|
||||||
|
start: entry.batch_time.unwrap(),
|
||||||
|
}))?;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Send message
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(stopped)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send errors to Infer for all `entries`
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||||
|
entries.drain().for_each(|(_, entry)| {
|
||||||
|
// Create and enter a span to link this function back to the entry
|
||||||
|
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||||
|
let err = InferError::GenerationError(error.to_string());
|
||||||
|
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
||||||
|
tracing::error!("{err}");
|
||||||
|
|
||||||
|
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||||
|
entry
|
||||||
|
.response_tx
|
||||||
|
.send(Err(err))
|
||||||
|
.unwrap_or(());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||||
|
fn from(value: crate::client::GeneratedText) -> Self {
|
||||||
|
let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||||
|
let finish_reason = match v2_finish_reason {
|
||||||
|
crate::client::FinishReason::Length => FinishReason::Length,
|
||||||
|
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||||
|
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||||
|
};
|
||||||
|
|
||||||
|
Self {
|
||||||
|
text: value.text,
|
||||||
|
generated_tokens: value.generated_tokens,
|
||||||
|
finish_reason,
|
||||||
|
seed: value.seed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,257 @@
|
||||||
|
/// Single shard Client
|
||||||
|
use crate::client::pb;
|
||||||
|
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||||
|
use grpc_metadata::InjectTelemetryContext;
|
||||||
|
use pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;
|
||||||
|
use pb::generate::v2::*;
|
||||||
|
use std::cmp::min;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tonic::transport::{Channel, Uri};
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
|
/// Text Generation Inference gRPC client
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Client {
|
||||||
|
stub: TextGenerationServiceClient<Channel>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Client {
|
||||||
|
/// Returns a client connected to the given url
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||||
|
let channel = Channel::builder(uri).connect().await?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
stub: TextGenerationServiceClient::new(channel),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given unix socket
|
||||||
|
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||||
|
let channel = Channel::from_shared("http://[::]:50051".to_string())
|
||||||
|
.unwrap()
|
||||||
|
.connect_with_connector(tower::service_fn(move |_: Uri| {
|
||||||
|
tokio::net::UnixStream::connect(path.clone())
|
||||||
|
}))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
stub: TextGenerationServiceClient::new(channel),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a list of uris or unix sockets of all shards
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||||
|
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||||
|
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||||
|
ClientError::Connection("Server does not support v2 interface".to_string())
|
||||||
|
})?;
|
||||||
|
let urls = response
|
||||||
|
.into_inner()
|
||||||
|
.urls
|
||||||
|
.into_iter()
|
||||||
|
// Remove unix socket prefix
|
||||||
|
.map(|url| match url.strip_prefix("unix://") {
|
||||||
|
None => url,
|
||||||
|
Some(stripped_url) => stripped_url.to_string(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Ok(urls)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get model info
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||||
|
let request = tonic::Request::new(InfoRequest {}).inject_context();
|
||||||
|
let response = self.stub.info(request).await?.into_inner();
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get model health
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||||
|
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||||
|
let response = self.stub.health(request).await?.into_inner();
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the past generations cache
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
|
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||||
|
self.stub.clear_cache(request).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a cached batch
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn filter_batch(
|
||||||
|
&mut self,
|
||||||
|
batch_id: u64,
|
||||||
|
request_ids: Vec<u64>,
|
||||||
|
) -> Result<Option<CachedBatch>> {
|
||||||
|
let request = tonic::Request::new(FilterBatchRequest {
|
||||||
|
batch_id,
|
||||||
|
request_ids,
|
||||||
|
})
|
||||||
|
.inject_context();
|
||||||
|
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||||
|
Ok(filtered_batch.batch)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Warmup on a max size batch
|
||||||
|
///
|
||||||
|
/// Returns the maximum amount of tokens supported by the hardware
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
pub async fn warmup(
|
||||||
|
&mut self,
|
||||||
|
max_input_length: u32,
|
||||||
|
max_prefill_tokens: u32,
|
||||||
|
max_total_tokens: u32,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<Option<u32>> {
|
||||||
|
let mut n_tokens = 0;
|
||||||
|
let mut requests = Vec::new();
|
||||||
|
// Create requests
|
||||||
|
while n_tokens < max_prefill_tokens {
|
||||||
|
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||||
|
|
||||||
|
let mut inputs = String::new();
|
||||||
|
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||||
|
if n_tokens == 0 {
|
||||||
|
// 1 request is enough to test vision heads.
|
||||||
|
// Sending images on other queries messes up easily with truncation.
|
||||||
|
inputs.push_str(&format!(
|
||||||
|
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
requests.push(Request {
|
||||||
|
id: 0,
|
||||||
|
inputs,
|
||||||
|
// We truncate the input on the server side to be sure that it has the correct size
|
||||||
|
truncate,
|
||||||
|
// Set sampling parameters to also take these ops into account in the max memory
|
||||||
|
parameters: Some(NextTokenChooserParameters {
|
||||||
|
temperature: 0.9,
|
||||||
|
top_k: 10,
|
||||||
|
top_p: 0.9,
|
||||||
|
typical_p: 0.9,
|
||||||
|
do_sample: false,
|
||||||
|
seed: 0,
|
||||||
|
repetition_penalty: 1.2,
|
||||||
|
frequency_penalty: 0.1,
|
||||||
|
watermark: true,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: GrammarType::None as i32,
|
||||||
|
}),
|
||||||
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
|
max_new_tokens: max_total_tokens - truncate,
|
||||||
|
stop_sequences: vec![],
|
||||||
|
ignore_eos_token: true,
|
||||||
|
}),
|
||||||
|
prefill_logprobs: true,
|
||||||
|
top_n_tokens: 20,
|
||||||
|
});
|
||||||
|
n_tokens += max_input_length;
|
||||||
|
|
||||||
|
// Check max_batch_size
|
||||||
|
if Some(requests.len()) == max_batch_size {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let batch = Batch {
|
||||||
|
id: 0,
|
||||||
|
size: requests.len() as u32,
|
||||||
|
requests,
|
||||||
|
max_tokens: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = tonic::Request::new(WarmupRequest {
|
||||||
|
batch: Some(batch),
|
||||||
|
max_input_length,
|
||||||
|
max_prefill_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
})
|
||||||
|
.inject_context();
|
||||||
|
let response = self.stub.warmup(request).await?.into_inner();
|
||||||
|
Ok(response.max_supported_total_tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given batch
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batch
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
||||||
|
pub async fn prefill(
|
||||||
|
&mut self,
|
||||||
|
batch: Batch,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
|
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||||
|
let response = self.stub.prefill(request).await?.into_inner();
|
||||||
|
Ok((
|
||||||
|
response.generations,
|
||||||
|
response.batch,
|
||||||
|
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given cached batches
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batches
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
||||||
|
pub async fn decode(
|
||||||
|
&mut self,
|
||||||
|
batches: Vec<CachedBatch>,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||||
|
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
|
||||||
|
let response = self.stub.decode(request).await?.into_inner();
|
||||||
|
Ok((
|
||||||
|
response.generations,
|
||||||
|
response.batch,
|
||||||
|
DecodeTimings::new(
|
||||||
|
response.concat_ns,
|
||||||
|
response.forward_ns,
|
||||||
|
response.decode_ns,
|
||||||
|
response.total_ns,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PrefillTimings {
|
||||||
|
pub forward: Duration,
|
||||||
|
pub decode: Duration,
|
||||||
|
pub total: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PrefillTimings {
|
||||||
|
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
forward: Duration::from_nanos(forward_ns),
|
||||||
|
decode: Duration::from_nanos(decode_ns),
|
||||||
|
total: Duration::from_nanos(total_ns),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DecodeTimings {
|
||||||
|
pub concat: Option<Duration>,
|
||||||
|
pub forward: Duration,
|
||||||
|
pub decode: Duration,
|
||||||
|
pub total: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecodeTimings {
|
||||||
|
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
concat: concat_ns.map(Duration::from_nanos),
|
||||||
|
forward: Duration::from_nanos(forward_ns),
|
||||||
|
decode: Duration::from_nanos(decode_ns),
|
||||||
|
total: Duration::from_nanos(total_ns),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
//! Text Generation gRPC client library
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tonic::transport;
|
||||||
|
use tonic::Status;
|
||||||
|
|
||||||
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
|
mod pb;
|
||||||
|
|
||||||
|
mod grpc_client;
|
||||||
|
mod sharded_client;
|
||||||
|
|
||||||
|
pub use grpc_client::Client;
|
||||||
|
pub use pb::generate::v2::{
|
||||||
|
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse,
|
||||||
|
InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
|
pub use sharded_client::ShardedClient;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Health {
|
||||||
|
/// Check if a generate server is healthy by asking it to allocate a tensor on device
|
||||||
|
async fn device_health(&self) -> Result<()>;
|
||||||
|
|
||||||
|
/// Check if a generate server is healthy by doing a forward pass.
|
||||||
|
/// EXPENSIVE
|
||||||
|
async fn model_health(&self) -> Result<()>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ShardInfo {
|
||||||
|
pub requires_padding: bool,
|
||||||
|
pub dtype: String,
|
||||||
|
pub device_type: String,
|
||||||
|
pub window_size: Option<u32>,
|
||||||
|
pub speculate: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug, Clone)]
|
||||||
|
pub enum ClientError {
|
||||||
|
#[error("Could not connect to Text Generation server: {0}")]
|
||||||
|
Connection(String),
|
||||||
|
#[error("Server error: {0}")]
|
||||||
|
Generation(String),
|
||||||
|
#[error("Sharded results are empty")]
|
||||||
|
EmptyResults,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Status> for ClientError {
|
||||||
|
fn from(err: Status) -> Self {
|
||||||
|
let err = Self::Generation(err.message().to_string());
|
||||||
|
tracing::error!("{err}");
|
||||||
|
err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<transport::Error> for ClientError {
|
||||||
|
fn from(err: transport::Error) -> Self {
|
||||||
|
let err = Self::Connection(err.to_string());
|
||||||
|
tracing::error!("{err}");
|
||||||
|
err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||||
|
|
||||||
|
pub type Result<T> = std::result::Result<T, ClientError>;
|
|
@ -0,0 +1,252 @@
|
||||||
|
/// Multi shard Client
|
||||||
|
use crate::client::{ClientError, Result};
|
||||||
|
use crate::client::{Health, ShardInfo};
|
||||||
|
|
||||||
|
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||||
|
use crate::client::InfoResponse;
|
||||||
|
use crate::client::{
|
||||||
|
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||||
|
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::future::join_all;
|
||||||
|
use tonic::transport::Uri;
|
||||||
|
use tracing::instrument;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// Text Generation Inference gRPC multi client
|
||||||
|
pub struct ShardedClient {
|
||||||
|
clients: Vec<Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ShardedClient {
|
||||||
|
fn new(clients: Vec<Client>) -> Self {
|
||||||
|
Self { clients }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||||
|
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||||
|
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||||
|
// Get all uris/unix sockets from the master client
|
||||||
|
let uris = master_client.service_discovery().await?;
|
||||||
|
let futures = uris.into_iter().map(Client::connect_uds);
|
||||||
|
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||||
|
Ok(Self::new(clients?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given uri
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||||
|
let master_client = Client::connect(uri).await?;
|
||||||
|
Self::from_master_client(master_client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a client connected to the given unix socket
|
||||||
|
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||||
|
let master_client = Client::connect_uds(path).await?;
|
||||||
|
Self::from_master_client(master_client).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the model info
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.info())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GRPC health check
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.health())
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the past generations cache
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| client.clear_cache(batch_id))
|
||||||
|
.collect();
|
||||||
|
join_all(futures).await.into_iter().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter a cached batch
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn filter_batch(
|
||||||
|
&mut self,
|
||||||
|
batch_id: u64,
|
||||||
|
request_ids: Vec<u64>,
|
||||||
|
) -> Result<Option<CachedBatch>> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
||||||
|
.collect();
|
||||||
|
// all shards return the same message
|
||||||
|
join_all(futures).await.pop().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Warmup on a max size batch
|
||||||
|
///
|
||||||
|
/// Returns the maximum amount of tokens supported by the hardware
|
||||||
|
#[instrument(skip(self))]
|
||||||
|
pub async fn warmup(
|
||||||
|
&mut self,
|
||||||
|
max_input_length: u32,
|
||||||
|
max_prefill_tokens: u32,
|
||||||
|
max_total_tokens: u32,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<Option<u32>> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| {
|
||||||
|
Box::pin(client.warmup(
|
||||||
|
max_input_length,
|
||||||
|
max_prefill_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
// Take the minimum value
|
||||||
|
let results = join_all(futures)
|
||||||
|
.await
|
||||||
|
.into_iter()
|
||||||
|
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||||
|
Ok(results.into_iter().flatten().min())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given batch
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batch
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
|
||||||
|
pub async fn prefill(
|
||||||
|
&mut self,
|
||||||
|
batch: Batch,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||||
|
.collect();
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||||
|
join_all(futures).await.into_iter().collect();
|
||||||
|
let mut results = results?;
|
||||||
|
|
||||||
|
let (mut generations, next_batch, mut timings) =
|
||||||
|
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||||
|
|
||||||
|
// Merge generations from different model shards
|
||||||
|
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||||
|
generations.append(&mut shard_generations);
|
||||||
|
// Return the timings of the slowest shard
|
||||||
|
if shard_timings.total > timings.total {
|
||||||
|
timings = shard_timings;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((generations, next_batch, timings))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate one token for each request in the given cached batches
|
||||||
|
///
|
||||||
|
/// Returns Generation for each request in batches
|
||||||
|
/// and the next cached batch
|
||||||
|
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
|
||||||
|
pub async fn decode(
|
||||||
|
&mut self,
|
||||||
|
batches: Vec<CachedBatch>,
|
||||||
|
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||||
|
let futures: Vec<_> = self
|
||||||
|
.clients
|
||||||
|
.iter_mut()
|
||||||
|
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||||
|
.collect();
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
||||||
|
join_all(futures).await.into_iter().collect();
|
||||||
|
let mut results = results?;
|
||||||
|
|
||||||
|
let (mut generations, next_batch, mut timings) =
|
||||||
|
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||||
|
|
||||||
|
// Merge generations from different model shards
|
||||||
|
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||||
|
generations.append(&mut shard_generations);
|
||||||
|
// Return the timings of the slowest shard
|
||||||
|
if shard_timings.total > timings.total {
|
||||||
|
timings = shard_timings;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((generations, next_batch, timings))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<InfoResponse> for ShardInfo {
|
||||||
|
fn from(value: InfoResponse) -> Self {
|
||||||
|
Self {
|
||||||
|
requires_padding: value.requires_padding,
|
||||||
|
dtype: value.dtype,
|
||||||
|
device_type: value.device_type,
|
||||||
|
window_size: value.window_size,
|
||||||
|
speculate: value.speculate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Health for ShardedClient {
|
||||||
|
async fn device_health(&self) -> Result<()> {
|
||||||
|
self.clone().health().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn model_health(&self) -> Result<()> {
|
||||||
|
// Dummy batch of 1 token and 1 generated token
|
||||||
|
let liveness_request = Request {
|
||||||
|
id: u64::MAX,
|
||||||
|
inputs: "liveness".to_string(),
|
||||||
|
truncate: 10,
|
||||||
|
prefill_logprobs: false,
|
||||||
|
parameters: Some(NextTokenChooserParameters {
|
||||||
|
temperature: 1.0,
|
||||||
|
top_k: 0,
|
||||||
|
top_p: 1.0,
|
||||||
|
typical_p: 1.0,
|
||||||
|
do_sample: false,
|
||||||
|
seed: 0,
|
||||||
|
repetition_penalty: 1.0,
|
||||||
|
frequency_penalty: 0.0,
|
||||||
|
watermark: false,
|
||||||
|
grammar: String::new(),
|
||||||
|
grammar_type: GrammarType::None as i32,
|
||||||
|
}),
|
||||||
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
|
max_new_tokens: 1,
|
||||||
|
stop_sequences: vec![],
|
||||||
|
ignore_eos_token: false,
|
||||||
|
}),
|
||||||
|
top_n_tokens: 0,
|
||||||
|
};
|
||||||
|
let batch = Batch {
|
||||||
|
id: u64::MAX,
|
||||||
|
requests: vec![liveness_request],
|
||||||
|
size: 1,
|
||||||
|
max_tokens: 2,
|
||||||
|
};
|
||||||
|
self.clone().prefill(batch).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,141 @@
|
||||||
|
mod backend;
|
||||||
|
mod client;
|
||||||
|
mod queue;
|
||||||
|
|
||||||
|
use crate::client::{ClientError, ShardedClient};
|
||||||
|
pub(crate) use backend::BackendV2;
|
||||||
|
use serde::Serialize;
|
||||||
|
use thiserror::Error;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||||
|
pub struct BackendInfo {
|
||||||
|
/// Mandatory
|
||||||
|
#[schema(example = "cuda")]
|
||||||
|
pub model_device_type: String,
|
||||||
|
#[schema(example = "torch.float16")]
|
||||||
|
pub model_dtype: String,
|
||||||
|
|
||||||
|
/// Backend parameters
|
||||||
|
#[schema(example = "1")]
|
||||||
|
pub speculate: usize,
|
||||||
|
#[schema(example = "1.2")]
|
||||||
|
pub waiting_served_ratio: f32,
|
||||||
|
#[schema(example = "32000")]
|
||||||
|
pub max_batch_total_tokens: u32,
|
||||||
|
#[schema(example = "20")]
|
||||||
|
pub max_waiting_tokens: usize,
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub max_batch_size: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub async fn connect_backend(
|
||||||
|
max_input_tokens: usize,
|
||||||
|
max_total_tokens: usize,
|
||||||
|
master_shard_uds_path: String,
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
max_batch_total_tokens: Option<u32>,
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
) -> Result<(BackendV2, BackendInfo), V2Error> {
|
||||||
|
// Helper function
|
||||||
|
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||||
|
match max_supported_batch_total_tokens {
|
||||||
|
// Older models do not support automatic max-batch-total-tokens
|
||||||
|
None => {
|
||||||
|
let max_batch_total_tokens = max_batch_total_tokens
|
||||||
|
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
||||||
|
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||||
|
Ok(max_batch_total_tokens)
|
||||||
|
}
|
||||||
|
// Flash attention models return their max supported total tokens
|
||||||
|
Some(max_supported_batch_total_tokens) => {
|
||||||
|
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||||
|
if max_batch_total_tokens.is_some() {
|
||||||
|
tracing::warn!(
|
||||||
|
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||||
|
Attention models."
|
||||||
|
);
|
||||||
|
tracing::warn!(
|
||||||
|
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||||
|
return Err(V2Error::NotEnoughMemory(max_total_tokens));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(max_supported_batch_total_tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||||
|
.await
|
||||||
|
.map_err(V2Error::Connection)?;
|
||||||
|
|
||||||
|
// server is running on v2
|
||||||
|
// Clear the cache; useful if the webserver rebooted
|
||||||
|
sharded_client
|
||||||
|
.clear_cache(None)
|
||||||
|
.await
|
||||||
|
.map_err(V2Error::Cache)?;
|
||||||
|
// Get info from the shard
|
||||||
|
let shard_info = sharded_client.info().await.map_err(V2Error::Info)?;
|
||||||
|
|
||||||
|
// Warmup model
|
||||||
|
tracing::info!("Warming up model");
|
||||||
|
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||||
|
sharded_client
|
||||||
|
.warmup(
|
||||||
|
max_input_tokens as u32,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_total_tokens as u32,
|
||||||
|
max_batch_size,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(V2Error::Warmup)?,
|
||||||
|
)?;
|
||||||
|
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||||
|
|
||||||
|
let backend_info = BackendInfo {
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
model_device_type: shard_info.device_type.clone(),
|
||||||
|
model_dtype: shard_info.dtype.clone(),
|
||||||
|
speculate: shard_info.speculate as usize,
|
||||||
|
};
|
||||||
|
|
||||||
|
let backend = BackendV2::new(
|
||||||
|
sharded_client,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
shard_info.requires_padding,
|
||||||
|
shard_info.window_size,
|
||||||
|
shard_info.speculate,
|
||||||
|
);
|
||||||
|
|
||||||
|
tracing::info!("Using backend V3");
|
||||||
|
|
||||||
|
Ok((backend, backend_info))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum V2Error {
|
||||||
|
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||||
|
Cache(ClientError),
|
||||||
|
#[error("Unable to connect to the Python model shards: {0}")]
|
||||||
|
Connection(ClientError),
|
||||||
|
#[error("Unable to get the Python model shards info: {0}")]
|
||||||
|
Info(ClientError),
|
||||||
|
#[error("Unable to warmup the Python model shards: {0}")]
|
||||||
|
Warmup(ClientError),
|
||||||
|
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||||
|
NotEnoughMemory(usize),
|
||||||
|
}
|
|
@ -0,0 +1,212 @@
|
||||||
|
use clap::{Parser, Subcommand};
|
||||||
|
use text_generation_router::{server, usage_stats};
|
||||||
|
use text_generation_router_v2::{connect_backend, V2Error};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// App Configuration
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[clap(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Option<Commands>,
|
||||||
|
|
||||||
|
#[clap(default_value = "128", long, env)]
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
#[clap(default_value = "2", long, env)]
|
||||||
|
max_best_of: usize,
|
||||||
|
#[clap(default_value = "4", long, env)]
|
||||||
|
max_stop_sequences: usize,
|
||||||
|
#[clap(default_value = "5", long, env)]
|
||||||
|
max_top_n_tokens: u32,
|
||||||
|
#[clap(default_value = "1024", long, env)]
|
||||||
|
max_input_tokens: usize,
|
||||||
|
#[clap(default_value = "2048", long, env)]
|
||||||
|
max_total_tokens: usize,
|
||||||
|
#[clap(default_value = "1.2", long, env)]
|
||||||
|
waiting_served_ratio: f32,
|
||||||
|
#[clap(default_value = "4096", long, env)]
|
||||||
|
max_batch_prefill_tokens: u32,
|
||||||
|
#[clap(long, env)]
|
||||||
|
max_batch_total_tokens: Option<u32>,
|
||||||
|
#[clap(default_value = "20", long, env)]
|
||||||
|
max_waiting_tokens: usize,
|
||||||
|
#[clap(long, env)]
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
#[clap(default_value = "0.0.0.0", long, env)]
|
||||||
|
hostname: String,
|
||||||
|
#[clap(default_value = "3000", long, short, env)]
|
||||||
|
port: u16,
|
||||||
|
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||||
|
master_shard_uds_path: String,
|
||||||
|
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||||
|
tokenizer_name: String,
|
||||||
|
#[clap(long, env)]
|
||||||
|
tokenizer_config_path: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
revision: Option<String>,
|
||||||
|
#[clap(default_value = "2", long, env)]
|
||||||
|
validation_workers: usize,
|
||||||
|
#[clap(long, env)]
|
||||||
|
api_key: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
json_output: bool,
|
||||||
|
#[clap(long, env)]
|
||||||
|
otlp_endpoint: Option<String>,
|
||||||
|
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||||
|
otlp_service_name: String,
|
||||||
|
#[clap(long, env)]
|
||||||
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok: bool,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_authtoken: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_edge: Option<String>,
|
||||||
|
#[clap(long, env, default_value_t = false)]
|
||||||
|
messages_api_enabled: bool,
|
||||||
|
#[clap(long, env, default_value_t = false)]
|
||||||
|
disable_grammar_support: bool,
|
||||||
|
#[clap(default_value = "4", long, env)]
|
||||||
|
max_client_batch_size: usize,
|
||||||
|
#[clap(default_value = "on", long, env)]
|
||||||
|
usage_stats: usage_stats::UsageStatsLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Subcommand)]
|
||||||
|
enum Commands {
|
||||||
|
PrintSchema,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), RouterError> {
|
||||||
|
// Get args
|
||||||
|
let args = Args::parse();
|
||||||
|
// Pattern match configuration
|
||||||
|
let Args {
|
||||||
|
command,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
master_shard_uds_path,
|
||||||
|
tokenizer_name,
|
||||||
|
tokenizer_config_path,
|
||||||
|
revision,
|
||||||
|
validation_workers,
|
||||||
|
api_key,
|
||||||
|
json_output,
|
||||||
|
otlp_endpoint,
|
||||||
|
otlp_service_name,
|
||||||
|
cors_allow_origin,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_edge,
|
||||||
|
messages_api_enabled,
|
||||||
|
disable_grammar_support,
|
||||||
|
max_client_batch_size,
|
||||||
|
usage_stats,
|
||||||
|
} = args;
|
||||||
|
|
||||||
|
if let Some(Commands::PrintSchema) = command {
|
||||||
|
use utoipa::OpenApi;
|
||||||
|
let api_doc = text_generation_router::server::ApiDoc::openapi();
|
||||||
|
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
|
||||||
|
println!("{}", api_doc);
|
||||||
|
std::process::exit(0);
|
||||||
|
};
|
||||||
|
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||||
|
|
||||||
|
// Validate args
|
||||||
|
if max_input_tokens >= max_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if validation_workers == 0 {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`validation_workers` must be > 0".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||||
|
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||||
|
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(max_batch_size) = max_batch_size {
|
||||||
|
if max_batch_size == 0 {
|
||||||
|
return Err(RouterError::ArgumentValidation(
|
||||||
|
"`max_batch_size` must be > 0".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (backend, _backend_info) = connect_backend(
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
master_shard_uds_path,
|
||||||
|
waiting_served_ratio,
|
||||||
|
max_batch_prefill_tokens,
|
||||||
|
max_batch_total_tokens,
|
||||||
|
max_waiting_tokens,
|
||||||
|
max_batch_size,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Run server
|
||||||
|
server::run(
|
||||||
|
backend,
|
||||||
|
max_concurrent_requests,
|
||||||
|
max_best_of,
|
||||||
|
max_stop_sequences,
|
||||||
|
max_top_n_tokens,
|
||||||
|
max_input_tokens,
|
||||||
|
max_total_tokens,
|
||||||
|
validation_workers,
|
||||||
|
api_key,
|
||||||
|
tokenizer_name,
|
||||||
|
tokenizer_config_path,
|
||||||
|
revision,
|
||||||
|
hostname,
|
||||||
|
port,
|
||||||
|
cors_allow_origin,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_edge,
|
||||||
|
messages_api_enabled,
|
||||||
|
disable_grammar_support,
|
||||||
|
max_client_batch_size,
|
||||||
|
usage_stats,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
enum RouterError {
|
||||||
|
#[error("Argument validation error: {0}")]
|
||||||
|
ArgumentValidation(String),
|
||||||
|
#[error("Backend failed: {0}")]
|
||||||
|
Backend(#[from] V2Error),
|
||||||
|
#[error("WebServer error: {0}")]
|
||||||
|
WebServer(#[from] server::WebServerError),
|
||||||
|
#[error("Tokio runtime failed to start: {0}")]
|
||||||
|
Tokio(#[from] std::io::Error),
|
||||||
|
}
|
|
@ -1,14 +1,14 @@
|
||||||
use crate::infer::{InferError, InferStreamResponse};
|
use crate::client::{
|
||||||
use crate::validation::{
|
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
|
||||||
};
|
};
|
||||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::collections::VecDeque;
|
use std::collections::VecDeque;
|
||||||
use text_generation_client::v2::{
|
use text_generation_router::infer::InferError;
|
||||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
use text_generation_router::infer::InferStreamResponse;
|
||||||
|
use text_generation_router::validation::{
|
||||||
|
ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||||
};
|
};
|
||||||
use text_generation_client::ChunksToString;
|
|
||||||
use tokio::sync::{mpsc, oneshot};
|
use tokio::sync::{mpsc, oneshot};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tracing::{info_span, instrument, Span};
|
use tracing::{info_span, instrument, Span};
|
||||||
|
@ -218,7 +218,7 @@ impl State {
|
||||||
|
|
||||||
// Create span for this batch to add context to inference calls
|
// Create span for this batch to add context to inference calls
|
||||||
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
|
||||||
next_batch_span.follows_from(&Span::current());
|
next_batch_span.follows_from(Span::current());
|
||||||
|
|
||||||
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
let mut batch_requests = Vec::with_capacity(self.entries.len());
|
||||||
let mut batch_entries =
|
let mut batch_entries =
|
||||||
|
@ -404,6 +404,7 @@ impl From<ValidStoppingParameters> for StoppingCriteriaParameters {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use std::sync::Arc;
|
||||||
use tracing::info_span;
|
use tracing::info_span;
|
||||||
|
|
||||||
fn default_entry() -> (
|
fn default_entry() -> (
|
||||||
|
@ -415,7 +416,9 @@ mod tests {
|
||||||
let entry = Entry {
|
let entry = Entry {
|
||||||
request: ValidGenerateRequest {
|
request: ValidGenerateRequest {
|
||||||
inputs: vec![],
|
inputs: vec![],
|
||||||
|
input_ids: Some(Arc::new(vec![])),
|
||||||
input_length: 0,
|
input_length: 0,
|
||||||
|
add_special_tokens: true,
|
||||||
truncate: 0,
|
truncate: 0,
|
||||||
decoder_input_details: false,
|
decoder_input_details: false,
|
||||||
parameters: ValidParameters {
|
parameters: ValidParameters {
|
|
@ -16,7 +16,6 @@ path = "src/main.rs"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
average = "0.14"
|
average = "0.14"
|
||||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||||
crossterm = "0.28.1"
|
|
||||||
float-ord = "0.3.2"
|
float-ord = "0.3.2"
|
||||||
serde = {version = "1.0.188", features = ["derive"]}
|
serde = {version = "1.0.188", features = ["derive"]}
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
|
@ -25,7 +24,7 @@ text-generation-client = { path = "../backends/client" }
|
||||||
thiserror = "1.0.48"
|
thiserror = "1.0.48"
|
||||||
tokenizers = { workspace = true }
|
tokenizers = { workspace = true }
|
||||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
||||||
ratatui = { version = "0.28.1", default-features = false, features = ["crossterm"] }
|
ratatui = "0.28.1"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||||
hf-hub = { workspace = true }
|
hf-hub = { workspace = true }
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
A lightweight benchmarking tool based inspired by [oha](https://github.com/hatoo/oha)
|
A lightweight benchmarking tool based inspired by [oha](https://github.com/hatoo/oha)
|
||||||
and powered by [tui](https://github.com/tui-rs-revival/ratatui).
|
and powered by [Ratatui](https://github.com/ratatui/ratatui).
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
|
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
|
||||||
use crate::generation::{Decode, Message, Prefill};
|
use crate::generation::{Decode, Message, Prefill};
|
||||||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||||
use ratatui::layout::{Alignment, Constraint, Direction, Layout};
|
use ratatui::layout::{Alignment, Constraint, Direction, Layout};
|
||||||
use ratatui::style::{Color, Modifier, Style};
|
use ratatui::style::{Color, Modifier, Style};
|
||||||
use ratatui::text::{Line, Span};
|
use ratatui::text::{Line, Span};
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
|
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
|
||||||
use crossterm::event;
|
use ratatui::crossterm::event;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio::sync::{broadcast, mpsc};
|
use tokio::sync::{broadcast, mpsc};
|
||||||
|
|
||||||
|
|
|
@ -6,8 +6,8 @@ mod utils;
|
||||||
|
|
||||||
use crate::app::App;
|
use crate::app::App;
|
||||||
use crate::event::Event;
|
use crate::event::Event;
|
||||||
use crossterm::ExecutableCommand;
|
|
||||||
use ratatui::backend::CrosstermBackend;
|
use ratatui::backend::CrosstermBackend;
|
||||||
|
use ratatui::crossterm::ExecutableCommand;
|
||||||
use ratatui::Terminal;
|
use ratatui::Terminal;
|
||||||
use std::io;
|
use std::io;
|
||||||
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
|
use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
|
||||||
|
@ -50,9 +50,9 @@ pub async fn run(
|
||||||
};
|
};
|
||||||
|
|
||||||
// Initialize terminal properties
|
// Initialize terminal properties
|
||||||
crossterm::terminal::enable_raw_mode()?;
|
ratatui::crossterm::terminal::enable_raw_mode()?;
|
||||||
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
|
io::stdout().execute(ratatui::crossterm::terminal::EnterAlternateScreen)?;
|
||||||
io::stdout().execute(crossterm::cursor::Hide)?;
|
io::stdout().execute(ratatui::crossterm::cursor::Hide)?;
|
||||||
|
|
||||||
// Initialize terminal
|
// Initialize terminal
|
||||||
let mut terminal = {
|
let mut terminal = {
|
||||||
|
@ -128,9 +128,9 @@ pub async fn run(
|
||||||
let _ = shutdown_guard_receiver.recv().await;
|
let _ = shutdown_guard_receiver.recv().await;
|
||||||
|
|
||||||
// Revert terminal to original view
|
// Revert terminal to original view
|
||||||
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
|
io::stdout().execute(ratatui::crossterm::terminal::LeaveAlternateScreen)?;
|
||||||
crossterm::terminal::disable_raw_mode()?;
|
ratatui::crossterm::terminal::disable_raw_mode()?;
|
||||||
io::stdout().execute(crossterm::cursor::Show)?;
|
io::stdout().execute(ratatui::crossterm::cursor::Show)?;
|
||||||
|
|
||||||
let parameters_table = table::parameters_table(
|
let parameters_table = table::parameters_table(
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
"name": "Apache 2.0",
|
"name": "Apache 2.0",
|
||||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||||
},
|
},
|
||||||
"version": "2.2.1-dev0"
|
"version": "2.3.1-dev0"
|
||||||
},
|
},
|
||||||
"paths": {
|
"paths": {
|
||||||
"/": {
|
"/": {
|
||||||
|
|
|
@ -36,7 +36,7 @@ To use LoRA in TGI, when starting the server, you can specify the list of LoRA m
|
||||||
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
LORA_ADAPTERS=predibase/customer_support,predibase/dbpedia
|
||||||
```
|
```
|
||||||
|
|
||||||
additionally, you can specify the path to the LoRA models using the `LORA_ADAPTERS_PATH` environment variable. For example:
|
To use a locally stored lora adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `"parameters": {"adapter_id": "adapter-name"}"`
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
|
LORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter
|
||||||
|
@ -72,6 +72,22 @@ curl 127.0.0.1:3000/generate \
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you are using a lora adapter stored locally that was set in the following manner: `LORA_ADAPTERS=myadapter=/some/path/to/adapter`, here is an example payload:
|
||||||
|
|
||||||
|
```json
|
||||||
|
curl 127.0.0.1:3000/generate \
|
||||||
|
-X POST \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"inputs": "Hello who are you?",
|
||||||
|
"parameters": {
|
||||||
|
"max_new_tokens": 40,
|
||||||
|
"adapter_id": "myadapter"
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
> **Note:** The Lora feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally documentation and an improved client library will be published soon.
|
> **Note:** The Lora feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally documentation and an improved client library will be published soon.
|
||||||
|
|
||||||
An updated tutorial with detailed examples will be published soon. Stay tuned!
|
An updated tutorial with detailed examples will be published soon. Stay tuned!
|
||||||
|
|
|
@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||||
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
|
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
|
||||||
--device=/dev/kfd --device=/dev/dri --group-add video \
|
--device=/dev/kfd --device=/dev/dri --group-add video \
|
||||||
--ipc=host --shm-size 256g --net host -v $volume:/data \
|
--ipc=host --shm-size 256g --net host -v $volume:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:2.2.0-rocm \
|
ghcr.io/huggingface/text-generation-inference:2.3.0-rocm \
|
||||||
--model-id $model
|
--model-id $model
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||||
docker run --rm --privileged --cap-add=sys_nice \
|
docker run --rm --privileged --cap-add=sys_nice \
|
||||||
--device=/dev/dri \
|
--device=/dev/dri \
|
||||||
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:2.2.0-intel-xpu \
|
ghcr.io/huggingface/text-generation-inference:2.3.0-intel-xpu \
|
||||||
--model-id $model --cuda-graphs 0
|
--model-id $model --cuda-graphs 0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
||||||
docker run --rm --privileged --cap-add=sys_nice \
|
docker run --rm --privileged --cap-add=sys_nice \
|
||||||
--device=/dev/dri \
|
--device=/dev/dri \
|
||||||
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
--ipc=host --shm-size 1g --net host -v $volume:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:2.2.0-intel-cpu \
|
ghcr.io/huggingface/text-generation-inference:2.3.0-intel-cpu \
|
||||||
--model-id $model --cuda-graphs 0
|
--model-id $model --cuda-graphs 0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
||||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
|
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:2.2.0 \
|
ghcr.io/huggingface/text-generation-inference:2.3.0 \
|
||||||
--model-id $model
|
--model-id $model
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -11,10 +11,19 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
||||||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||||
|
|
||||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||||
ghcr.io/huggingface/text-generation-inference:2.2.0 \
|
ghcr.io/huggingface/text-generation-inference:2.3.0 \
|
||||||
--model-id $model
|
--model-id $model
|
||||||
```
|
```
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
If you want to serve gated or private models, which provide
|
||||||
|
controlled access to sensitive or proprietary content, refer to
|
||||||
|
[this guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/gated_model_access)
|
||||||
|
for detailed instructions.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
### Supported hardware
|
### Supported hardware
|
||||||
|
|
||||||
TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
|
TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
|
||||||
|
|
|
@ -24,13 +24,13 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 1736,
|
"id": 1736,
|
||||||
"logprob": -2.03125,
|
"logprob": -2.109375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " form"
|
"text": " form"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 109,
|
"id": 109,
|
||||||
"logprob": -1.8671875,
|
"logprob": -1.90625,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": "\n\n"
|
"text": "\n\n"
|
||||||
},
|
},
|
||||||
|
@ -42,48 +42,48 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2121,
|
"id": 2121,
|
||||||
"logprob": -1.8125,
|
"logprob": -1.796875,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " test"
|
"text": " test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3853,
|
"id": 3853,
|
||||||
"logprob": -0.24121094,
|
"logprob": -0.24511719,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 1736,
|
"id": 1736,
|
||||||
"logprob": -0.100097656,
|
"logprob": -0.09326172,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " form"
|
"text": " form"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 603,
|
"id": 603,
|
||||||
"logprob": -0.9453125,
|
"logprob": -0.95703125,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " is"
|
"text": " is"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 476,
|
"id": 1671,
|
||||||
"logprob": -1.703125,
|
"logprob": -1.5859375,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " a"
|
"text": " used"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 4551,
|
"id": 577,
|
||||||
"logprob": -2.453125,
|
"logprob": -0.39257812,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " document"
|
"text": " to"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 674,
|
"id": 3853,
|
||||||
"logprob": -0.796875,
|
"logprob": -1.25,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " that"
|
"text": " request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": " form\n\nThe test request form is a document that"
|
"generated_text": " form\n\nThe test request form is used to request"
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,12 +11,12 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 2015,
|
"id": 2015,
|
||||||
"logprob": -9.640625,
|
"logprob": -9.6484375,
|
||||||
"text": "Test"
|
"text": "Test"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 3853,
|
"id": 3853,
|
||||||
"logprob": -10.375,
|
"logprob": -10.3671875,
|
||||||
"text": " request"
|
"text": " request"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -24,19 +24,19 @@
|
||||||
"tokens": [
|
"tokens": [
|
||||||
{
|
{
|
||||||
"id": 604,
|
"id": 604,
|
||||||
"logprob": -0.2824707,
|
"logprob": -0.28271484,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " for"
|
"text": " for"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 573,
|
"id": 573,
|
||||||
"logprob": -0.19030762,
|
"logprob": -0.18493652,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " the"
|
"text": " the"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": 16819,
|
"id": 16819,
|
||||||
"logprob": -1.4892578,
|
"logprob": -1.4804688,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " detection"
|
"text": " detection"
|
||||||
},
|
},
|
||||||
|
@ -46,44 +46,44 @@
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " of"
|
"text": " of"
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": 573,
|
|
||||||
"logprob": -2.0195312,
|
|
||||||
"special": false,
|
|
||||||
"text": " the"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 8566,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " presence"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 689,
|
|
||||||
"logprob": -0.16491699,
|
|
||||||
"special": false,
|
|
||||||
"text": " or"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 14862,
|
|
||||||
"logprob": 0.0,
|
|
||||||
"special": false,
|
|
||||||
"text": " absence"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 576,
|
|
||||||
"logprob": -0.9946289,
|
|
||||||
"special": false,
|
|
||||||
"text": " of"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": 671,
|
"id": 671,
|
||||||
"logprob": -0.5263672,
|
"logprob": -2.1738281,
|
||||||
"special": false,
|
"special": false,
|
||||||
"text": " an"
|
"text": " an"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 24646,
|
||||||
|
"logprob": -3.0449219,
|
||||||
|
"special": false,
|
||||||
|
"text": " RNA"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 12369,
|
||||||
|
"logprob": -0.19299316,
|
||||||
|
"special": false,
|
||||||
|
"text": " virus"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 575,
|
||||||
|
"logprob": -0.10632324,
|
||||||
|
"special": false,
|
||||||
|
"text": " in"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 6022,
|
||||||
|
"logprob": -0.98095703,
|
||||||
|
"special": false,
|
||||||
|
"text": " patients"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1064,
|
||||||
|
"logprob": -1.3095703,
|
||||||
|
"special": false,
|
||||||
|
"text": " who"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"top_tokens": null
|
"top_tokens": null
|
||||||
},
|
},
|
||||||
"generated_text": "Test request for the detection of the presence or absence of an"
|
"generated_text": "Test request for the detection of an RNA virus in patients who"
|
||||||
}
|
}
|
||||||
|
|
|
@ -1038,6 +1038,7 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
|
||||||
Ok(log) => log.trace(),
|
Ok(log) => log.trace(),
|
||||||
// For interactive debugging ?
|
// For interactive debugging ?
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
|
if LevelFilter::current() >= tracing::Level::DEBUG {
|
||||||
stdout.write_all(line).unwrap();
|
stdout.write_all(line).unwrap();
|
||||||
if lines.peek().is_some() {
|
if lines.peek().is_some() {
|
||||||
stdout.write_all(b"\n").unwrap();
|
stdout.write_all(b"\n").unwrap();
|
||||||
|
@ -1049,6 +1050,7 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_num_shards(
|
fn find_num_shards(
|
||||||
|
|
|
@ -8,9 +8,11 @@ use crate::{
|
||||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||||
Message, PrefillToken, Token,
|
Message, PrefillToken, Token,
|
||||||
};
|
};
|
||||||
|
use async_stream::stream;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chat_template::ChatTemplate;
|
use chat_template::ChatTemplate;
|
||||||
use futures::future::try_join_all;
|
use futures::future::try_join_all;
|
||||||
|
use futures::Stream;
|
||||||
use minijinja::ErrorKind;
|
use minijinja::ErrorKind;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
@ -87,7 +89,14 @@ impl Infer {
|
||||||
pub(crate) async fn generate_stream<'a>(
|
pub(crate) async fn generate_stream<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
request: GenerateRequest,
|
request: GenerateRequest,
|
||||||
) -> Result<GenerateStreamResponse, InferError> {
|
) -> Result<
|
||||||
|
(
|
||||||
|
OwnedSemaphorePermit,
|
||||||
|
u32, // input_length
|
||||||
|
impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,
|
||||||
|
),
|
||||||
|
InferError,
|
||||||
|
> {
|
||||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||||
let permit = self
|
let permit = self
|
||||||
.clone()
|
.clone()
|
||||||
|
@ -107,9 +116,18 @@ impl Infer {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let input_length = valid_request.input_length;
|
let input_length = valid_request.input_length;
|
||||||
let generation_stream = self.backend.schedule(valid_request)?;
|
let mut generation_stream = self.backend.schedule(valid_request)?;
|
||||||
|
|
||||||
Ok((permit, input_length, generation_stream))
|
// Wrap generation stream to update the backend health if the stream contains an error
|
||||||
|
let final_stream = stream! {
|
||||||
|
while let Some(response) = generation_stream.next().await {
|
||||||
|
yield response.inspect_err(|_err| {
|
||||||
|
self.backend_health.store(false, Ordering::SeqCst);
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((permit, input_length, final_stream))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Tokenizer the input
|
/// Tokenizer the input
|
||||||
|
@ -278,13 +296,6 @@ impl Infer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Type alias for generation responses
|
|
||||||
pub(crate) type GenerateStreamResponse = (
|
|
||||||
OwnedSemaphorePermit,
|
|
||||||
u32, // input_length
|
|
||||||
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
|
|
||||||
);
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct GeneratedText {
|
pub struct GeneratedText {
|
||||||
pub text: String,
|
pub text: String,
|
||||||
|
|
|
@ -1,4 +0,0 @@
|
||||||
mod queue;
|
|
||||||
mod scheduler;
|
|
||||||
|
|
||||||
pub(crate) use scheduler::BackendV2;
|
|
File diff suppressed because it is too large
Load Diff
|
@ -9,7 +9,10 @@ mod kserve;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
|
|
||||||
pub mod usage_stats;
|
pub mod usage_stats;
|
||||||
|
mod vertex;
|
||||||
|
|
||||||
|
use crate::infer::{Infer, InferError};
|
||||||
|
use crate::server::prepare_chat_input;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
use utoipa::ToSchema;
|
use utoipa::ToSchema;
|
||||||
|
@ -54,32 +57,6 @@ impl std::str::FromStr for Attention {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema)]
|
|
||||||
pub(crate) struct GenerateVertexInstance {
|
|
||||||
#[schema(example = "What is Deep Learning?")]
|
|
||||||
pub inputs: String,
|
|
||||||
#[schema(nullable = true, default = "null", example = "null")]
|
|
||||||
pub parameters: Option<GenerateParameters>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
enum VertexInstance {
|
|
||||||
Generate(GenerateVertexInstance),
|
|
||||||
Chat(ChatRequest),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, ToSchema)]
|
|
||||||
pub(crate) struct VertexRequest {
|
|
||||||
#[serde(rename = "instances")]
|
|
||||||
pub instances: Vec<VertexInstance>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
|
||||||
pub(crate) struct VertexResponse {
|
|
||||||
pub predictions: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Hub type
|
/// Hub type
|
||||||
#[derive(Clone, Debug, Deserialize)]
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
pub struct HubModelInfo {
|
pub struct HubModelInfo {
|
||||||
|
@ -174,6 +151,7 @@ impl HubProcessorConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]
|
||||||
|
#[cfg_attr(test, derive(PartialEq))]
|
||||||
#[serde(tag = "type", content = "value")]
|
#[serde(tag = "type", content = "value")]
|
||||||
pub(crate) enum GrammarType {
|
pub(crate) enum GrammarType {
|
||||||
/// A string that represents a [JSON Schema](https://json-schema.org/).
|
/// A string that represents a [JSON Schema](https://json-schema.org/).
|
||||||
|
@ -230,6 +208,7 @@ pub struct Info {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
|
#[derive(Clone, Debug, Deserialize, ToSchema, Default)]
|
||||||
|
#[cfg_attr(test, derive(PartialEq))]
|
||||||
pub(crate) struct GenerateParameters {
|
pub(crate) struct GenerateParameters {
|
||||||
/// Generate best_of sequences and return the one if the highest token logprobs.
|
/// Generate best_of sequences and return the one if the highest token logprobs.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
@ -774,6 +753,7 @@ impl ChatCompletionChunk {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
#[cfg_attr(test, derive(Debug, PartialEq, Default))]
|
||||||
pub(crate) struct ChatRequest {
|
pub(crate) struct ChatRequest {
|
||||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||||
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
||||||
|
@ -890,7 +870,82 @@ pub(crate) struct ChatRequest {
|
||||||
pub stream_options: Option<StreamOptions>,
|
pub stream_options: Option<StreamOptions>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ChatRequest {
|
||||||
|
fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {
|
||||||
|
let ChatRequest {
|
||||||
|
model,
|
||||||
|
max_tokens,
|
||||||
|
messages,
|
||||||
|
seed,
|
||||||
|
stop,
|
||||||
|
stream,
|
||||||
|
tools,
|
||||||
|
tool_choice,
|
||||||
|
tool_prompt,
|
||||||
|
temperature,
|
||||||
|
response_format,
|
||||||
|
guideline,
|
||||||
|
presence_penalty,
|
||||||
|
frequency_penalty,
|
||||||
|
top_p,
|
||||||
|
top_logprobs,
|
||||||
|
..
|
||||||
|
} = self;
|
||||||
|
|
||||||
|
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||||
|
let max_new_tokens = max_tokens.or(Some(100));
|
||||||
|
let tool_prompt = tool_prompt
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.unwrap_or_else(default_tool_prompt);
|
||||||
|
let stop = stop.unwrap_or_default();
|
||||||
|
// enable greedy only when temperature is 0
|
||||||
|
let (do_sample, temperature) = match temperature {
|
||||||
|
Some(temperature) if temperature == 0.0 => (false, None),
|
||||||
|
other => (true, other),
|
||||||
|
};
|
||||||
|
let (inputs, grammar, using_tools) = prepare_chat_input(
|
||||||
|
infer,
|
||||||
|
response_format,
|
||||||
|
tools,
|
||||||
|
tool_choice,
|
||||||
|
&tool_prompt,
|
||||||
|
guideline,
|
||||||
|
messages,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok((
|
||||||
|
GenerateRequest {
|
||||||
|
inputs: inputs.to_string(),
|
||||||
|
add_special_tokens: false,
|
||||||
|
parameters: GenerateParameters {
|
||||||
|
best_of: None,
|
||||||
|
temperature,
|
||||||
|
repetition_penalty,
|
||||||
|
frequency_penalty,
|
||||||
|
top_k: None,
|
||||||
|
top_p,
|
||||||
|
typical_p: None,
|
||||||
|
do_sample,
|
||||||
|
max_new_tokens,
|
||||||
|
return_full_text: None,
|
||||||
|
stop,
|
||||||
|
truncate: None,
|
||||||
|
watermark: false,
|
||||||
|
details: true,
|
||||||
|
decoder_input_details: !stream,
|
||||||
|
seed,
|
||||||
|
top_n_tokens: top_logprobs,
|
||||||
|
grammar,
|
||||||
|
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
using_tools,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||||
struct StreamOptions {
|
struct StreamOptions {
|
||||||
/// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
|
/// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
|
||||||
#[schema(example = "true")]
|
#[schema(example = "true")]
|
||||||
|
@ -984,6 +1039,7 @@ pub(crate) struct FunctionDefinition {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
|
||||||
|
#[cfg_attr(test, derive(PartialEq))]
|
||||||
pub(crate) struct Tool {
|
pub(crate) struct Tool {
|
||||||
// The type of the tool. Currently, only 'function' is supported.
|
// The type of the tool. Currently, only 'function' is supported.
|
||||||
#[schema(example = "function")]
|
#[schema(example = "function")]
|
||||||
|
|
|
@ -8,7 +8,8 @@ use crate::kserve::{
|
||||||
kserve_model_metadata, kserve_model_metadata_ready,
|
kserve_model_metadata, kserve_model_metadata_ready,
|
||||||
};
|
};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
|
use crate::vertex::vertex_compatibility;
|
||||||
|
use crate::ChatTokenizeResponse;
|
||||||
use crate::{
|
use crate::{
|
||||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||||
|
@ -20,8 +21,7 @@ use crate::{
|
||||||
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
||||||
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
|
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
|
||||||
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
|
||||||
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
|
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
|
||||||
VertexResponse,
|
|
||||||
};
|
};
|
||||||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
||||||
use crate::{ModelInfo, ModelsInfo};
|
use crate::{ModelInfo, ModelsInfo};
|
||||||
|
@ -149,63 +149,11 @@ async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
|
||||||
)]
|
)]
|
||||||
async fn get_chat_tokenize(
|
async fn get_chat_tokenize(
|
||||||
Extension(infer): Extension<Infer>,
|
Extension(infer): Extension<Infer>,
|
||||||
Json(req): Json<ChatRequest>,
|
Json(chat): Json<ChatRequest>,
|
||||||
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||||
metrics::counter!("tgi_request_count").increment(1);
|
metrics::counter!("tgi_request_count").increment(1);
|
||||||
|
|
||||||
let ChatRequest {
|
let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;
|
||||||
model,
|
|
||||||
max_tokens,
|
|
||||||
messages,
|
|
||||||
seed,
|
|
||||||
stop,
|
|
||||||
stream,
|
|
||||||
tools,
|
|
||||||
tool_choice,
|
|
||||||
tool_prompt,
|
|
||||||
temperature,
|
|
||||||
response_format,
|
|
||||||
guideline,
|
|
||||||
..
|
|
||||||
} = req;
|
|
||||||
|
|
||||||
let tool_prompt = tool_prompt.unwrap_or_default();
|
|
||||||
let (inputs, _grammar, _using_tools) = prepare_chat_input(
|
|
||||||
&infer,
|
|
||||||
response_format,
|
|
||||||
tools,
|
|
||||||
tool_choice,
|
|
||||||
&tool_prompt,
|
|
||||||
guideline,
|
|
||||||
messages,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let generate_request = GenerateRequest {
|
|
||||||
inputs,
|
|
||||||
add_special_tokens: false,
|
|
||||||
parameters: GenerateParameters {
|
|
||||||
best_of: None,
|
|
||||||
temperature,
|
|
||||||
repetition_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
top_k: None,
|
|
||||||
top_p: None,
|
|
||||||
typical_p: None,
|
|
||||||
do_sample: true,
|
|
||||||
max_new_tokens: max_tokens,
|
|
||||||
return_full_text: None,
|
|
||||||
stop: stop.unwrap_or_default(),
|
|
||||||
truncate: None,
|
|
||||||
watermark: false,
|
|
||||||
details: false,
|
|
||||||
decoder_input_details: !stream,
|
|
||||||
seed,
|
|
||||||
top_n_tokens: None,
|
|
||||||
grammar: _grammar,
|
|
||||||
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
let input = generate_request.inputs.clone();
|
let input = generate_request.inputs.clone();
|
||||||
let encoding = infer.tokenize(generate_request).await?;
|
let encoding = infer.tokenize(generate_request).await?;
|
||||||
if let Some(encoding) = encoding {
|
if let Some(encoding) = encoding {
|
||||||
|
@ -1162,77 +1110,20 @@ async fn chat_completions(
|
||||||
Extension(infer): Extension<Infer>,
|
Extension(infer): Extension<Infer>,
|
||||||
Extension(compute_type): Extension<ComputeType>,
|
Extension(compute_type): Extension<ComputeType>,
|
||||||
Extension(info): Extension<Info>,
|
Extension(info): Extension<Info>,
|
||||||
Json(req): Json<ChatRequest>,
|
Json(chat): Json<ChatRequest>,
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
metrics::counter!("tgi_request_count").increment(1);
|
metrics::counter!("tgi_request_count").increment(1);
|
||||||
let ChatRequest {
|
let ChatRequest {
|
||||||
model,
|
|
||||||
logprobs,
|
|
||||||
max_tokens,
|
|
||||||
messages,
|
|
||||||
presence_penalty,
|
|
||||||
seed,
|
|
||||||
stop,
|
|
||||||
stream,
|
stream,
|
||||||
stream_options,
|
stream_options,
|
||||||
tools,
|
logprobs,
|
||||||
tool_choice,
|
|
||||||
tool_prompt,
|
|
||||||
temperature,
|
|
||||||
response_format,
|
|
||||||
guideline,
|
|
||||||
..
|
..
|
||||||
} = req;
|
} = chat.clone();
|
||||||
|
let (generate_request, using_tools): (GenerateRequest, bool) =
|
||||||
|
chat.try_into_generate(&infer)?;
|
||||||
|
|
||||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
let logprobs = logprobs.unwrap_or_default();
|
||||||
let max_new_tokens = max_tokens.or(Some(100));
|
|
||||||
let logprobs = logprobs.unwrap_or(false);
|
|
||||||
let tool_prompt = tool_prompt
|
|
||||||
.filter(|s| !s.is_empty())
|
|
||||||
.unwrap_or_else(default_tool_prompt);
|
|
||||||
let stop = stop.unwrap_or_default();
|
|
||||||
// enable greedy only when temperature is 0
|
|
||||||
let (do_sample, temperature) = match temperature {
|
|
||||||
Some(temperature) if temperature == 0.0 => (false, None),
|
|
||||||
other => (true, other),
|
|
||||||
};
|
|
||||||
let (inputs, grammar, using_tools) = prepare_chat_input(
|
|
||||||
&infer,
|
|
||||||
response_format,
|
|
||||||
tools,
|
|
||||||
tool_choice,
|
|
||||||
&tool_prompt,
|
|
||||||
guideline,
|
|
||||||
messages,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// build the request passing some parameters
|
|
||||||
let generate_request = GenerateRequest {
|
|
||||||
inputs: inputs.to_string(),
|
|
||||||
add_special_tokens: false,
|
|
||||||
parameters: GenerateParameters {
|
|
||||||
best_of: None,
|
|
||||||
temperature,
|
|
||||||
repetition_penalty,
|
|
||||||
frequency_penalty: req.frequency_penalty,
|
|
||||||
top_k: None,
|
|
||||||
top_p: req.top_p,
|
|
||||||
typical_p: None,
|
|
||||||
do_sample,
|
|
||||||
max_new_tokens,
|
|
||||||
return_full_text: None,
|
|
||||||
stop,
|
|
||||||
truncate: None,
|
|
||||||
watermark: false,
|
|
||||||
details: true,
|
|
||||||
decoder_input_details: !stream,
|
|
||||||
seed,
|
|
||||||
top_n_tokens: req.top_logprobs,
|
|
||||||
grammar,
|
|
||||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
// static values that will be returned in all cases
|
// static values that will be returned in all cases
|
||||||
let model_id = info.model_id.clone();
|
let model_id = info.model_id.clone();
|
||||||
|
@ -1385,186 +1276,6 @@ async fn chat_completions(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate tokens from Vertex request
|
|
||||||
#[utoipa::path(
|
|
||||||
post,
|
|
||||||
tag = "Text Generation Inference",
|
|
||||||
path = "/vertex",
|
|
||||||
request_body = VertexRequest,
|
|
||||||
responses(
|
|
||||||
(status = 200, description = "Generated Text", body = VertexResponse),
|
|
||||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
|
||||||
example = json ! ({"error": "Request failed during generation"})),
|
|
||||||
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
|
||||||
example = json ! ({"error": "Model is overloaded"})),
|
|
||||||
(status = 422, description = "Input validation error", body = ErrorResponse,
|
|
||||||
example = json ! ({"error": "Input validation error"})),
|
|
||||||
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
|
||||||
example = json ! ({"error": "Incomplete generation"})),
|
|
||||||
)
|
|
||||||
)]
|
|
||||||
#[instrument(
|
|
||||||
skip_all,
|
|
||||||
fields(
|
|
||||||
total_time,
|
|
||||||
validation_time,
|
|
||||||
queue_time,
|
|
||||||
inference_time,
|
|
||||||
time_per_token,
|
|
||||||
seed,
|
|
||||||
)
|
|
||||||
)]
|
|
||||||
async fn vertex_compatibility(
|
|
||||||
Extension(infer): Extension<Infer>,
|
|
||||||
Extension(compute_type): Extension<ComputeType>,
|
|
||||||
Json(req): Json<VertexRequest>,
|
|
||||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
|
||||||
let span = tracing::Span::current();
|
|
||||||
metrics::counter!("tgi_request_count").increment(1);
|
|
||||||
|
|
||||||
// check that theres at least one instance
|
|
||||||
if req.instances.is_empty() {
|
|
||||||
return Err((
|
|
||||||
StatusCode::UNPROCESSABLE_ENTITY,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: "Input validation error".to_string(),
|
|
||||||
error_type: "Input validation error".to_string(),
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare futures for all instances
|
|
||||||
let mut futures = Vec::with_capacity(req.instances.len());
|
|
||||||
|
|
||||||
for instance in req.instances.iter() {
|
|
||||||
let generate_request = match instance {
|
|
||||||
VertexInstance::Generate(instance) => GenerateRequest {
|
|
||||||
inputs: instance.inputs.clone(),
|
|
||||||
add_special_tokens: true,
|
|
||||||
parameters: GenerateParameters {
|
|
||||||
do_sample: true,
|
|
||||||
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
|
|
||||||
seed: instance.parameters.as_ref().and_then(|p| p.seed),
|
|
||||||
details: true,
|
|
||||||
decoder_input_details: true,
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
VertexInstance::Chat(instance) => {
|
|
||||||
let ChatRequest {
|
|
||||||
model,
|
|
||||||
max_tokens,
|
|
||||||
messages,
|
|
||||||
seed,
|
|
||||||
stop,
|
|
||||||
stream,
|
|
||||||
tools,
|
|
||||||
tool_choice,
|
|
||||||
tool_prompt,
|
|
||||||
temperature,
|
|
||||||
response_format,
|
|
||||||
guideline,
|
|
||||||
presence_penalty,
|
|
||||||
frequency_penalty,
|
|
||||||
top_p,
|
|
||||||
top_logprobs,
|
|
||||||
..
|
|
||||||
} = instance.clone();
|
|
||||||
|
|
||||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
|
||||||
let max_new_tokens = max_tokens.or(Some(100));
|
|
||||||
let tool_prompt = tool_prompt
|
|
||||||
.filter(|s| !s.is_empty())
|
|
||||||
.unwrap_or_else(default_tool_prompt);
|
|
||||||
let stop = stop.unwrap_or_default();
|
|
||||||
// enable greedy only when temperature is 0
|
|
||||||
let (do_sample, temperature) = match temperature {
|
|
||||||
Some(temperature) if temperature == 0.0 => (false, None),
|
|
||||||
other => (true, other),
|
|
||||||
};
|
|
||||||
let (inputs, grammar, _using_tools) = match prepare_chat_input(
|
|
||||||
&infer,
|
|
||||||
response_format,
|
|
||||||
tools,
|
|
||||||
tool_choice,
|
|
||||||
&tool_prompt,
|
|
||||||
guideline,
|
|
||||||
messages,
|
|
||||||
) {
|
|
||||||
Ok(result) => result,
|
|
||||||
Err(e) => {
|
|
||||||
return Err((
|
|
||||||
StatusCode::BAD_REQUEST,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: format!("Failed to prepare chat input: {}", e),
|
|
||||||
error_type: "Input preparation error".to_string(),
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
GenerateRequest {
|
|
||||||
inputs: inputs.to_string(),
|
|
||||||
add_special_tokens: false,
|
|
||||||
parameters: GenerateParameters {
|
|
||||||
best_of: None,
|
|
||||||
temperature,
|
|
||||||
repetition_penalty,
|
|
||||||
frequency_penalty,
|
|
||||||
top_k: None,
|
|
||||||
top_p,
|
|
||||||
typical_p: None,
|
|
||||||
do_sample,
|
|
||||||
max_new_tokens,
|
|
||||||
return_full_text: None,
|
|
||||||
stop,
|
|
||||||
truncate: None,
|
|
||||||
watermark: false,
|
|
||||||
details: true,
|
|
||||||
decoder_input_details: !stream,
|
|
||||||
seed,
|
|
||||||
top_n_tokens: top_logprobs,
|
|
||||||
grammar,
|
|
||||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let infer_clone = infer.clone();
|
|
||||||
let compute_type_clone = compute_type.clone();
|
|
||||||
let span_clone = span.clone();
|
|
||||||
|
|
||||||
futures.push(async move {
|
|
||||||
generate_internal(
|
|
||||||
Extension(infer_clone),
|
|
||||||
compute_type_clone,
|
|
||||||
Json(generate_request),
|
|
||||||
span_clone,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map(|(_, Json(generation))| generation.generated_text)
|
|
||||||
.map_err(|_| {
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
Json(ErrorResponse {
|
|
||||||
error: "Incomplete generation".into(),
|
|
||||||
error_type: "Incomplete generation".into(),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// execute all futures in parallel, collect results, returning early if any error occurs
|
|
||||||
let results = futures::future::join_all(futures).await;
|
|
||||||
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
|
|
||||||
let predictions = predictions?;
|
|
||||||
|
|
||||||
let response = VertexResponse { predictions };
|
|
||||||
Ok((HeaderMap::new(), Json(response)).into_response())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Tokenize inputs
|
/// Tokenize inputs
|
||||||
#[utoipa::path(
|
#[utoipa::path(
|
||||||
post,
|
post,
|
||||||
|
@ -2318,7 +2029,8 @@ async fn start(
|
||||||
|
|
||||||
#[cfg(feature = "google")]
|
#[cfg(feature = "google")]
|
||||||
{
|
{
|
||||||
use crate::VertexInstance;
|
use crate::vertex::__path_vertex_compatibility;
|
||||||
|
use crate::vertex::{VertexInstance, VertexRequest, VertexResponse};
|
||||||
|
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
#[openapi(
|
#[openapi(
|
||||||
|
@ -2637,7 +2349,7 @@ pub enum WebServerError {
|
||||||
|
|
||||||
type PreparedInput = (String, Option<GrammarType>, bool);
|
type PreparedInput = (String, Option<GrammarType>, bool);
|
||||||
|
|
||||||
fn prepare_chat_input(
|
pub(crate) fn prepare_chat_input(
|
||||||
infer: &Infer,
|
infer: &Infer,
|
||||||
response_format: Option<GrammarType>,
|
response_format: Option<GrammarType>,
|
||||||
tools: Option<Vec<Tool>>,
|
tools: Option<Vec<Tool>>,
|
||||||
|
|
|
@ -0,0 +1,360 @@
|
||||||
|
use crate::infer::Infer;
|
||||||
|
use crate::server::{generate_internal, ComputeType};
|
||||||
|
use crate::{
|
||||||
|
ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
|
||||||
|
StreamOptions, Tool, ToolChoice,
|
||||||
|
};
|
||||||
|
use axum::extract::Extension;
|
||||||
|
use axum::http::{HeaderMap, StatusCode};
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use axum::Json;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tracing::instrument;
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
|
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||||
|
pub(crate) struct GenerateVertexInstance {
|
||||||
|
#[schema(example = "What is Deep Learning?")]
|
||||||
|
pub inputs: String,
|
||||||
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
|
pub parameters: Option<GenerateParameters>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
|
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||||
|
pub(crate) struct VertexChat {
|
||||||
|
messages: Vec<Message>,
|
||||||
|
// Messages is ignored there.
|
||||||
|
#[serde(default)]
|
||||||
|
parameters: VertexParameters,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
|
||||||
|
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||||
|
pub(crate) struct VertexParameters {
|
||||||
|
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||||
|
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
||||||
|
pub model: Option<String>,
|
||||||
|
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
|
||||||
|
/// decreasing the model's likelihood to repeat the same line verbatim.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(example = "1.0")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// UNUSED
|
||||||
|
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
|
||||||
|
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
|
||||||
|
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
|
||||||
|
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
|
||||||
|
/// result in a ban or exclusive selection of the relevant token.
|
||||||
|
#[serde(default)]
|
||||||
|
pub logit_bias: Option<Vec<f32>>,
|
||||||
|
|
||||||
|
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
|
||||||
|
/// output token returned in the content of message.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(example = "false")]
|
||||||
|
pub logprobs: Option<bool>,
|
||||||
|
|
||||||
|
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
|
||||||
|
/// an associated log probability. logprobs must be set to true if this parameter is used.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(example = "5")]
|
||||||
|
pub top_logprobs: Option<u32>,
|
||||||
|
|
||||||
|
/// The maximum number of tokens that can be generated in the chat completion.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(example = "32")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// UNUSED
|
||||||
|
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
|
||||||
|
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = "2")]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
|
||||||
|
/// increasing the model's likelihood to talk about new topics
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = 0.1)]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Up to 4 sequences where the API will stop generating further tokens.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
|
||||||
|
#[serde(default = "bool::default")]
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
#[schema(nullable = true, example = 42)]
|
||||||
|
pub seed: Option<u64>,
|
||||||
|
|
||||||
|
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
|
||||||
|
/// lower values like 0.2 will make it more focused and deterministic.
|
||||||
|
///
|
||||||
|
/// We generally recommend altering this or `top_p` but not both.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = 1.0)]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
|
||||||
|
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
|
||||||
|
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = 0.95)]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
|
||||||
|
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
|
||||||
|
/// functions the model may generate JSON inputs for.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub tools: Option<Vec<Tool>>,
|
||||||
|
|
||||||
|
/// A prompt to be appended before the tools
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(
|
||||||
|
nullable = true,
|
||||||
|
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
|
||||||
|
)]
|
||||||
|
pub tool_prompt: Option<String>,
|
||||||
|
|
||||||
|
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub tool_choice: ToolChoice,
|
||||||
|
|
||||||
|
/// Response format constraints for the generation.
|
||||||
|
///
|
||||||
|
/// NOTE: A request can use `response_format` OR `tools` but not both.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
|
pub response_format: Option<GrammarType>,
|
||||||
|
|
||||||
|
/// A guideline to be used in the chat_template
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
|
pub guideline: Option<String>,
|
||||||
|
|
||||||
|
/// Options for streaming response. Only set this when you set stream: true.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub stream_options: Option<StreamOptions>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<VertexChat> for ChatRequest {
|
||||||
|
fn from(val: VertexChat) -> Self {
|
||||||
|
Self {
|
||||||
|
messages: val.messages,
|
||||||
|
frequency_penalty: val.parameters.frequency_penalty,
|
||||||
|
guideline: val.parameters.guideline,
|
||||||
|
logit_bias: val.parameters.logit_bias,
|
||||||
|
logprobs: val.parameters.logprobs,
|
||||||
|
max_tokens: val.parameters.max_tokens,
|
||||||
|
model: val.parameters.model,
|
||||||
|
n: val.parameters.n,
|
||||||
|
presence_penalty: val.parameters.presence_penalty,
|
||||||
|
response_format: val.parameters.response_format,
|
||||||
|
seed: val.parameters.seed,
|
||||||
|
stop: val.parameters.stop,
|
||||||
|
stream_options: val.parameters.stream_options,
|
||||||
|
stream: val.parameters.stream,
|
||||||
|
temperature: val.parameters.temperature,
|
||||||
|
tool_choice: val.parameters.tool_choice,
|
||||||
|
tool_prompt: val.parameters.tool_prompt,
|
||||||
|
tools: val.parameters.tools,
|
||||||
|
top_logprobs: val.parameters.top_logprobs,
|
||||||
|
top_p: val.parameters.top_p,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
|
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub(crate) enum VertexInstance {
|
||||||
|
Generate(GenerateVertexInstance),
|
||||||
|
Chat(VertexChat),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize, ToSchema)]
|
||||||
|
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||||
|
pub(crate) struct VertexRequest {
|
||||||
|
#[serde(rename = "instances")]
|
||||||
|
pub instances: Vec<VertexInstance>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
pub(crate) struct VertexResponse {
|
||||||
|
pub predictions: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate tokens from Vertex request
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
tag = "Text Generation Inference",
|
||||||
|
path = "/vertex",
|
||||||
|
request_body = VertexRequest,
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Generated Text", body = VertexResponse),
|
||||||
|
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Request failed during generation"})),
|
||||||
|
(status = 429, description = "Model is overloaded", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Model is overloaded"})),
|
||||||
|
(status = 422, description = "Input validation error", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Input validation error"})),
|
||||||
|
(status = 500, description = "Incomplete generation", body = ErrorResponse,
|
||||||
|
example = json ! ({"error": "Incomplete generation"})),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
#[instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
total_time,
|
||||||
|
validation_time,
|
||||||
|
queue_time,
|
||||||
|
inference_time,
|
||||||
|
time_per_token,
|
||||||
|
seed,
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
pub(crate) async fn vertex_compatibility(
|
||||||
|
Extension(infer): Extension<Infer>,
|
||||||
|
Extension(compute_type): Extension<ComputeType>,
|
||||||
|
Json(req): Json<VertexRequest>,
|
||||||
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
|
let span = tracing::Span::current();
|
||||||
|
metrics::counter!("tgi_request_count").increment(1);
|
||||||
|
|
||||||
|
// check that theres at least one instance
|
||||||
|
if req.instances.is_empty() {
|
||||||
|
return Err((
|
||||||
|
StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: "Input validation error".to_string(),
|
||||||
|
error_type: "Input validation error".to_string(),
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare futures for all instances
|
||||||
|
let mut futures = Vec::with_capacity(req.instances.len());
|
||||||
|
|
||||||
|
for instance in req.instances.into_iter() {
|
||||||
|
let generate_request = match instance {
|
||||||
|
VertexInstance::Generate(instance) => GenerateRequest {
|
||||||
|
inputs: instance.inputs.clone(),
|
||||||
|
add_special_tokens: true,
|
||||||
|
parameters: GenerateParameters {
|
||||||
|
do_sample: true,
|
||||||
|
max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
|
||||||
|
seed: instance.parameters.as_ref().and_then(|p| p.seed),
|
||||||
|
details: true,
|
||||||
|
decoder_input_details: true,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
VertexInstance::Chat(instance) => {
|
||||||
|
let chat_request: ChatRequest = instance.into();
|
||||||
|
let (generate_request, _using_tools): (GenerateRequest, bool) =
|
||||||
|
chat_request.try_into_generate(&infer)?;
|
||||||
|
generate_request
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let infer_clone = infer.clone();
|
||||||
|
let compute_type_clone = compute_type.clone();
|
||||||
|
let span_clone = span.clone();
|
||||||
|
|
||||||
|
futures.push(async move {
|
||||||
|
generate_internal(
|
||||||
|
Extension(infer_clone),
|
||||||
|
compute_type_clone,
|
||||||
|
Json(generate_request),
|
||||||
|
span_clone,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map(|(_, Json(generation))| generation.generated_text)
|
||||||
|
.map_err(|_| {
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: "Incomplete generation".into(),
|
||||||
|
error_type: "Incomplete generation".into(),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// execute all futures in parallel, collect results, returning early if any error occurs
|
||||||
|
let results = futures::future::join_all(futures).await;
|
||||||
|
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
|
||||||
|
let predictions = predictions?;
|
||||||
|
|
||||||
|
let response = VertexResponse { predictions };
|
||||||
|
Ok((HeaderMap::new(), Json(response)).into_response())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::{Message, MessageContent};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn vertex_deserialization() {
|
||||||
|
let string = serde_json::json!({
|
||||||
|
|
||||||
|
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
||||||
|
"parameters": {
|
||||||
|
"max_tokens": 128,
|
||||||
|
"top_p": 0.95,
|
||||||
|
"temperature": 0.7
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
|
||||||
|
|
||||||
|
let string = serde_json::json!({
|
||||||
|
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
||||||
|
});
|
||||||
|
|
||||||
|
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
|
||||||
|
|
||||||
|
let string = serde_json::json!({
|
||||||
|
|
||||||
|
"instances": [
|
||||||
|
{
|
||||||
|
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
|
||||||
|
"parameters": {
|
||||||
|
"max_tokens": 128,
|
||||||
|
"top_p": 0.95,
|
||||||
|
"temperature": 0.7
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
});
|
||||||
|
let request: VertexRequest = serde_json::from_value(string).expect("Can deserialize");
|
||||||
|
assert_eq!(
|
||||||
|
request,
|
||||||
|
VertexRequest {
|
||||||
|
instances: vec![VertexInstance::Chat(VertexChat {
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: MessageContent::SingleText("What's Deep Learning?".to_string()),
|
||||||
|
name: None,
|
||||||
|
},],
|
||||||
|
parameters: VertexParameters {
|
||||||
|
max_tokens: Some(128),
|
||||||
|
top_p: Some(0.95),
|
||||||
|
temperature: Some(0.7),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
})]
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -152,11 +152,13 @@ def create_decode_state(
|
||||||
):
|
):
|
||||||
"""Create a decode state."""
|
"""Create a decode state."""
|
||||||
workspace_buffer = get_workspace(device)
|
workspace_buffer = get_workspace(device)
|
||||||
|
num_groups = num_heads // num_kv_heads
|
||||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||||
workspace_buffer,
|
workspace_buffer,
|
||||||
kv_layout="NHD",
|
kv_layout="NHD",
|
||||||
use_cuda_graph=False,
|
use_cuda_graph=False,
|
||||||
use_tensor_cores=num_heads // num_kv_heads > 4,
|
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
|
||||||
|
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -175,6 +177,7 @@ def create_decode_state_cuda_graphs(
|
||||||
therefore stored as part of the state.
|
therefore stored as part of the state.
|
||||||
"""
|
"""
|
||||||
workspace_buffer = get_workspace(device)
|
workspace_buffer = get_workspace(device)
|
||||||
|
num_groups = num_heads // num_kv_heads
|
||||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||||
workspace_buffer,
|
workspace_buffer,
|
||||||
kv_layout="NHD",
|
kv_layout="NHD",
|
||||||
|
@ -182,7 +185,8 @@ def create_decode_state_cuda_graphs(
|
||||||
paged_kv_indices_buffer=block_tables,
|
paged_kv_indices_buffer=block_tables,
|
||||||
paged_kv_indptr_buffer=block_tables_ptr,
|
paged_kv_indptr_buffer=block_tables_ptr,
|
||||||
paged_kv_last_page_len_buffer=last_page_len,
|
paged_kv_last_page_len_buffer=last_page_len,
|
||||||
use_tensor_cores=num_heads // num_kv_heads > 4,
|
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
|
||||||
|
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -87,9 +87,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
|
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
scale = weights.get_tensor(
|
scale = (
|
||||||
f"{prefix}.weight_scale", to_dtype=False
|
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
).reshape(-1)
|
.reshape(-1)
|
||||||
|
.expand(w.shape[0])
|
||||||
|
)
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
|
@ -113,9 +115,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
|
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
if scale.numel() > 1:
|
||||||
scale = weights.get_packed_sharded(
|
scale = weights.get_packed_sharded(
|
||||||
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
|
f"{prefix}.weight_scale",
|
||||||
).reshape(-1)
|
dim=0,
|
||||||
|
block_sizes=block_sizes,
|
||||||
|
to_dtype=False,
|
||||||
|
)
|
||||||
|
scale = scale.reshape(-1).expand(w.shape[0])
|
||||||
|
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
|
@ -132,16 +141,19 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
w = [
|
w = [
|
||||||
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
|
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
|
||||||
]
|
]
|
||||||
|
shapes = [x.shape for x in w]
|
||||||
|
|
||||||
# Concat then send to the device
|
# Concat then send to the device
|
||||||
w = torch.cat(w, dim=dim).to(weights.device)
|
w = torch.cat(w, dim=dim).to(weights.device)
|
||||||
|
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
scale = [
|
scale = [
|
||||||
weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
|
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
|
||||||
for p in prefixes
|
for p, shape in zip(prefixes, shapes)
|
||||||
]
|
]
|
||||||
scale = torch.cat(scale, dim=0).reshape(-1)
|
scale = torch.cat(scale, dim=0).reshape(-1)
|
||||||
|
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
|
@ -157,9 +169,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||||
# FP8 branch
|
# FP8 branch
|
||||||
if w.dtype == torch.float8_e4m3fn:
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
scale = weights.get_tensor(
|
scale = (
|
||||||
f"{prefix}.weight_scale", to_dtype=False
|
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
).reshape(-1)
|
.reshape(-1)
|
||||||
|
.expand(w.shape[0])
|
||||||
|
)
|
||||||
return Fp8Weight(
|
return Fp8Weight(
|
||||||
weight=w,
|
weight=w,
|
||||||
weight_scale=scale,
|
weight_scale=scale,
|
||||||
|
@ -182,6 +196,9 @@ class Fp8Weight(Weight):
|
||||||
def get_linear(self, bias: torch.Tensor):
|
def get_linear(self, bias: torch.Tensor):
|
||||||
if self.weight_scale is None:
|
if self.weight_scale is None:
|
||||||
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
|
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
|
||||||
|
# This is not checked by the fbgemm kernels, but they require contiguous
|
||||||
|
# memory. Can be non-contiguous when we e.g. expand from scalars.
|
||||||
|
self.weight_scale = self.weight_scale.contiguous()
|
||||||
return get_fp8_linear().from_fp8(
|
return get_fp8_linear().from_fp8(
|
||||||
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
|
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
|
||||||
)
|
)
|
||||||
|
@ -222,6 +239,9 @@ class Fp8Linear(torch.nn.Module):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
|
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
|
||||||
|
if FBGEMM_DYN_AVAILABLE:
|
||||||
|
# fbgemm needs float32 scales.
|
||||||
|
scale = scale.float()
|
||||||
return cls(
|
return cls(
|
||||||
qweight=weight,
|
qweight=weight,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
|
@ -256,3 +276,10 @@ class Fp8Linear(torch.nn.Module):
|
||||||
bias=self.bias,
|
bias=self.bias,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
|
||||||
|
scale = weights.get_tensor(prefix, to_dtype=False)
|
||||||
|
if scale.numel() > 1:
|
||||||
|
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
|
||||||
|
return scale.reshape(-1).expand(shape[0])
|
||||||
|
|
|
@ -1,15 +1,181 @@
|
||||||
from typing import Optional
|
from typing import Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from loguru import logger
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
|
||||||
|
from text_generation_server.layers import (
|
||||||
|
TensorParallelColumnLinear,
|
||||||
|
TensorParallelRowLinear,
|
||||||
|
)
|
||||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||||
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
from text_generation_server.utils.log import log_once
|
||||||
from text_generation_server.utils.weights import (
|
from text_generation_server.utils.weights import (
|
||||||
DefaultWeightsLoader,
|
DefaultWeightsLoader,
|
||||||
UnquantizedWeight,
|
UnquantizedWeight,
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if SYSTEM == "rocm":
|
||||||
|
from .fused_moe_rocm import grouped_topk
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||||
|
elif SYSTEM != "ipex":
|
||||||
|
from moe_kernels.fused_moe import fused_topk, grouped_topk
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: we are using a protocol here, because multiple inherance is not nice.
|
||||||
|
# We need `Module`, and `Module` -> some abstract class -> some concrete
|
||||||
|
# class inheritance is whacky.
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class MoELayer(Protocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
n_expert_group: Optional[int],
|
||||||
|
n_experts: int,
|
||||||
|
prefix: str,
|
||||||
|
renormalize: bool,
|
||||||
|
topk: int,
|
||||||
|
topk_group: Optional[int],
|
||||||
|
weights: Weights,
|
||||||
|
gate_proj_name: str = "gate_proj",
|
||||||
|
up_proj_name: str = "up_proj",
|
||||||
|
down_proj_name: str = "down_proj",
|
||||||
|
hidden_act: str = "silu",
|
||||||
|
): ...
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x: torch.Tensor, *, gating_output: torch.Tensor
|
||||||
|
) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DenseMoELayer(nn.Module):
|
||||||
|
"""
|
||||||
|
Layer for MoE that applies *all* experts to each tokens and then weights
|
||||||
|
their outputs based on the calculated routing. This layer is much slower
|
||||||
|
than `SparseMoELayer` and should only be used when no fused kernels are
|
||||||
|
available (e.g. for unsupported quantizers).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
n_expert_group: Optional[int],
|
||||||
|
n_experts: int,
|
||||||
|
prefix: str,
|
||||||
|
renormalize: bool,
|
||||||
|
topk: int,
|
||||||
|
topk_group: Optional[int],
|
||||||
|
weights: Weights,
|
||||||
|
gate_proj_name: str = "gate_proj",
|
||||||
|
up_proj_name: str = "up_proj",
|
||||||
|
down_proj_name: str = "down_proj",
|
||||||
|
hidden_act: str = "silu",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
log_once(
|
||||||
|
logger.info,
|
||||||
|
"No fused layers are available for this model type, using (slower) dense MoE layer",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (n_expert_group is None) == (
|
||||||
|
topk_group is None
|
||||||
|
), "n_expert_group and topk_group must both be None or have some value"
|
||||||
|
|
||||||
|
self.n_expert_group = n_expert_group
|
||||||
|
self.n_experts = n_experts
|
||||||
|
self.renormalize = renormalize
|
||||||
|
self.topk = topk
|
||||||
|
self.topk_group = topk_group
|
||||||
|
|
||||||
|
if "gelu" in hidden_act:
|
||||||
|
self.act = lambda x: torch.nn.functional.gelu(
|
||||||
|
x,
|
||||||
|
approximate=(
|
||||||
|
"tanh"
|
||||||
|
if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||||
|
else "none"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif "silu" in hidden_act:
|
||||||
|
self.act = torch.nn.functional.silu
|
||||||
|
else:
|
||||||
|
self.act = ACT2FN[hidden_act]
|
||||||
|
|
||||||
|
self.gate_proj = [
|
||||||
|
TensorParallelColumnLinear.load(
|
||||||
|
None,
|
||||||
|
prefix=f"{prefix}.{i}.{gate_proj_name}",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
for i in range(self.n_experts)
|
||||||
|
]
|
||||||
|
self.up_proj = [
|
||||||
|
TensorParallelColumnLinear.load(
|
||||||
|
None,
|
||||||
|
prefix=f"{prefix}.{i}.{up_proj_name}",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
for i in range(self.n_experts)
|
||||||
|
]
|
||||||
|
self.down_proj = [
|
||||||
|
TensorParallelRowLinear.load(
|
||||||
|
None,
|
||||||
|
prefix=f"{prefix}.{i}.{down_proj_name}",
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
for i in range(self.n_experts)
|
||||||
|
]
|
||||||
|
|
||||||
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
x: (sequence_length, model_dim)
|
||||||
|
gating_output: (sequence_length, n_experts)
|
||||||
|
"""
|
||||||
|
# optional reshape
|
||||||
|
input_shape = x.shape
|
||||||
|
x = x.view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
if self.n_expert_group is not None and self.topk_group is not None:
|
||||||
|
topk_weights, topk_ids = grouped_topk(
|
||||||
|
x,
|
||||||
|
gating_output,
|
||||||
|
self.topk,
|
||||||
|
renormalize=self.renormalize,
|
||||||
|
num_expert_group=self.n_expert_group,
|
||||||
|
topk_group=self.topk_group,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
topk_weights, topk_ids = fused_topk(
|
||||||
|
x, gating_output, self.topk, self.renormalize
|
||||||
|
)
|
||||||
|
topk_weights = topk_weights.to(x.dtype)
|
||||||
|
|
||||||
|
weights = torch.zeros(
|
||||||
|
topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
|
||||||
|
)
|
||||||
|
|
||||||
|
weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))
|
||||||
|
|
||||||
|
out = torch.zeros_like(x)
|
||||||
|
for i in range(self.n_experts):
|
||||||
|
h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
|
||||||
|
h = self.down_proj[i](h, reduce=False)
|
||||||
|
out += h * weights[:, i].view(-1, 1)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class SparseMoELayer(nn.Module):
|
class SparseMoELayer(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -334,6 +334,7 @@ def get_model(
|
||||||
model_type = config_dict.get("model_type", None)
|
model_type = config_dict.get("model_type", None)
|
||||||
|
|
||||||
quantization_config = config_dict.get("quantization_config", None)
|
quantization_config = config_dict.get("quantization_config", None)
|
||||||
|
compression_config = config_dict.get("compression_config", None)
|
||||||
if quantization_config is not None and quantize is None:
|
if quantization_config is not None and quantize is None:
|
||||||
method = quantization_config.get("quant_method", None)
|
method = quantization_config.get("quant_method", None)
|
||||||
if method in {"gptq", "awq", "exl2"}:
|
if method in {"gptq", "awq", "exl2"}:
|
||||||
|
@ -344,6 +345,23 @@ def get_model(
|
||||||
quantize = "fp8"
|
quantize = "fp8"
|
||||||
else:
|
else:
|
||||||
log_master(logger.warning, f"Unknown quantization method {method}")
|
log_master(logger.warning, f"Unknown quantization method {method}")
|
||||||
|
elif compression_config is not None:
|
||||||
|
# TODO: at some point we should probably fully parse the compression
|
||||||
|
# configuration to know which parameters are compressed.
|
||||||
|
config_groups = compression_config.get("config_groups")
|
||||||
|
if config_groups is not None:
|
||||||
|
for _, group in config_groups.items():
|
||||||
|
weights_config = group.get("weights")
|
||||||
|
if weights_config is not None:
|
||||||
|
if (
|
||||||
|
weights_config["type"] == "float"
|
||||||
|
and weights_config["num_bits"] == 8
|
||||||
|
):
|
||||||
|
log_master(
|
||||||
|
logger.info, "Auto selecting quantization method fp8"
|
||||||
|
)
|
||||||
|
quantize = "fp8"
|
||||||
|
break
|
||||||
|
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
if quantize in ["awq", "exl2", "gptq", "marlin"]:
|
if quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||||
|
@ -768,7 +786,6 @@ def get_model(
|
||||||
)
|
)
|
||||||
|
|
||||||
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
|
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
|
||||||
print(f">>> model_type: {model_type}")
|
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
return FlashCausalLM(
|
return FlashCausalLM(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
|
|
@ -13,18 +13,17 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
from text_generation_server.layers import grouped_topk
|
|
||||||
elif SYSTEM != "ipex":
|
|
||||||
from moe_kernels.fused_moe import grouped_topk
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
FastLinear,
|
FastLinear,
|
||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
|
@ -34,18 +33,15 @@ from text_generation_server.layers import (
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
|
Seqlen,
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
reshape_and_cache,
|
reshape_and_cache,
|
||||||
Seqlen,
|
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||||
from text_generation_server.layers.moe import SparseMoELayer
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||||
from text_generation_server.utils.weights import Weights
|
from text_generation_server.utils.weights import Weights
|
||||||
from torch import nn
|
|
||||||
from transformers.activations import ACT2FN
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
if SYSTEM == "rocm":
|
||||||
try:
|
try:
|
||||||
|
@ -415,8 +411,14 @@ class DeepseekV2MLP(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BlockSparseMoE(nn.Module):
|
class DeepseekV2MoE(nn.Module):
|
||||||
def __init__(self, prefix, config: DeepseekV2Config, weights):
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefix,
|
||||||
|
config: DeepseekV2Config,
|
||||||
|
moe_layer_cls: Type[MoELayer],
|
||||||
|
weights,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.hidden_dim = config.hidden_size
|
self.hidden_dim = config.hidden_size
|
||||||
|
@ -428,7 +430,7 @@ class BlockSparseMoE(nn.Module):
|
||||||
# Gating
|
# Gating
|
||||||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||||
|
|
||||||
self.moe_layer = SparseMoELayer(
|
self.moe_layer = moe_layer_cls(
|
||||||
prefix=f"{prefix}.experts",
|
prefix=f"{prefix}.experts",
|
||||||
n_experts=config.n_routed_experts,
|
n_experts=config.n_routed_experts,
|
||||||
n_expert_group=config.n_group,
|
n_expert_group=config.n_group,
|
||||||
|
@ -437,6 +439,7 @@ class BlockSparseMoE(nn.Module):
|
||||||
topk_group=config.topk_group,
|
topk_group=config.topk_group,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
assert isinstance(self.moe_layer, MoELayer)
|
||||||
|
|
||||||
if config.n_shared_experts is not None:
|
if config.n_shared_experts is not None:
|
||||||
self.shared_experts = DeepseekV2MLP(
|
self.shared_experts = DeepseekV2MLP(
|
||||||
|
@ -471,96 +474,6 @@ class BlockSparseMoE(nn.Module):
|
||||||
return out.view(*x.shape)
|
return out.view(*x.shape)
|
||||||
|
|
||||||
|
|
||||||
class DenseMoE(nn.Module):
|
|
||||||
def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.hidden_dim = config.hidden_size
|
|
||||||
self.moe_intermediate_size = config.moe_intermediate_size
|
|
||||||
self.n_routed_experts = config.n_routed_experts
|
|
||||||
self.n_expert_group = config.n_group
|
|
||||||
self.topk_group = config.topk_group
|
|
||||||
self.top_k = config.num_experts_per_tok
|
|
||||||
self.norm_topk_prob = config.norm_topk_prob
|
|
||||||
self.routed_scaling_factor = config.routed_scaling_factor
|
|
||||||
|
|
||||||
# Gating
|
|
||||||
#
|
|
||||||
# Seems like no one quantizes the gate.
|
|
||||||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
|
||||||
|
|
||||||
self.experts = [
|
|
||||||
DeepseekV2MLP(
|
|
||||||
f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size
|
|
||||||
)
|
|
||||||
for i in range(self.n_routed_experts)
|
|
||||||
]
|
|
||||||
|
|
||||||
if config.n_shared_experts is not None:
|
|
||||||
self.shared_experts = DeepseekV2MLP(
|
|
||||||
prefix=f"{prefix}.shared_experts",
|
|
||||||
config=config,
|
|
||||||
weights=weights,
|
|
||||||
intermediate_size=config.moe_intermediate_size
|
|
||||||
* config.n_shared_experts,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.shared_experts = None
|
|
||||||
|
|
||||||
self.process_group = weights.process_group
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
x: (sequence_length, model_dim)
|
|
||||||
gate_logits: (sequence_length, n_experts)
|
|
||||||
"""
|
|
||||||
# optional reshape
|
|
||||||
input_shape = x.shape
|
|
||||||
x = x.view(-1, input_shape[-1])
|
|
||||||
|
|
||||||
if self.shared_experts is not None:
|
|
||||||
shared_output = self.shared_experts(x, reduce=False)
|
|
||||||
else:
|
|
||||||
shared_output = None
|
|
||||||
|
|
||||||
# gate_logits: (sequence_length, n_experts)
|
|
||||||
router_logits = self.gate(x)
|
|
||||||
|
|
||||||
topk_weights, topk_ids = grouped_topk(
|
|
||||||
x,
|
|
||||||
router_logits,
|
|
||||||
self.top_k,
|
|
||||||
renormalize=self.norm_topk_prob,
|
|
||||||
num_expert_group=self.n_expert_group,
|
|
||||||
topk_group=self.topk_group,
|
|
||||||
)
|
|
||||||
|
|
||||||
out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor
|
|
||||||
|
|
||||||
if shared_output is not None:
|
|
||||||
out = out + shared_output
|
|
||||||
|
|
||||||
# Reduce sum
|
|
||||||
if self.process_group.size() > 1:
|
|
||||||
torch.distributed.all_reduce(out, group=self.process_group)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
def moe_infer_gpu(
|
|
||||||
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
|
|
||||||
):
|
|
||||||
weights = torch.zeros(
|
|
||||||
topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device
|
|
||||||
)
|
|
||||||
weights.scatter_(1, topk_ids, topk_weight)
|
|
||||||
|
|
||||||
out = x.new_zeros(x.shape[0], self.hidden_dim)
|
|
||||||
for i, expert in enumerate(self.experts):
|
|
||||||
# Add expert output to out with masking
|
|
||||||
out += expert(x, reduce=False) * weights[:, i].view(-1, 1)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2Layer(nn.Module):
|
class DeepseekV2Layer(nn.Module):
|
||||||
def __init__(self, prefix, layer_id, config, weights):
|
def __init__(self, prefix, layer_id, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -577,10 +490,12 @@ class DeepseekV2Layer(nn.Module):
|
||||||
and layer_id >= config.first_k_dense_replace
|
and layer_id >= config.first_k_dense_replace
|
||||||
and layer_id % config.moe_layer_freq == 0
|
and layer_id % config.moe_layer_freq == 0
|
||||||
):
|
):
|
||||||
moe_cls = (
|
moe_layer_cls = (
|
||||||
BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
|
SparseMoELayer
|
||||||
|
if SparseMoELayer.is_supported(weights)
|
||||||
|
else DenseMoELayer
|
||||||
)
|
)
|
||||||
self.mlp = moe_cls(f"{prefix}.mlp", config, weights)
|
self.mlp = DeepseekV2MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
|
||||||
else:
|
else:
|
||||||
self.mlp = DeepseekV2MLP(
|
self.mlp = DeepseekV2MLP(
|
||||||
prefix=f"{prefix}.mlp",
|
prefix=f"{prefix}.mlp",
|
||||||
|
|
|
@ -38,6 +38,8 @@ from text_generation_server.layers import (
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
|
TensorParallelMultiAdapterLinear,
|
||||||
|
TensorParallelAdapterRowLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
|
@ -161,7 +163,9 @@ def _load_gqa(config, prefix: str, weights):
|
||||||
|
|
||||||
|
|
||||||
class FlashGemma2Attention(torch.nn.Module):
|
class FlashGemma2Attention(torch.nn.Module):
|
||||||
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
|
def __init__(
|
||||||
|
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.head_size = config.head_dim
|
self.head_size = config.head_dim
|
||||||
|
@ -192,14 +196,32 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
self.softcap = config.attn_logit_softcapping
|
self.softcap = config.attn_logit_softcapping
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.query_key_value = TensorParallelMultiAdapterLinear.load(
|
||||||
|
query_key_value,
|
||||||
|
layer_id,
|
||||||
|
["q_proj", "k_proj", "v_proj"],
|
||||||
|
sizes=[
|
||||||
|
self.head_size * config.num_attention_heads,
|
||||||
|
self.head_size * config.num_key_value_heads,
|
||||||
|
self.head_size * config.num_key_value_heads,
|
||||||
|
],
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.o_proj",
|
prefix=f"{prefix}.o_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
o_proj,
|
||||||
|
layer_id,
|
||||||
|
"o_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
self.kv_head_mapping = torch.arange(
|
self.kv_head_mapping = torch.arange(
|
||||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
@ -216,8 +238,9 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
[
|
[
|
||||||
self.head_size * self.num_heads,
|
self.head_size * self.num_heads,
|
||||||
|
@ -260,11 +283,13 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||||
softcap=self.softcap,
|
softcap=self.softcap,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(
|
||||||
|
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Gemma2MLP(nn.Module):
|
class Gemma2MLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights, layer_id):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act = config.hidden_activation
|
act = config.hidden_activation
|
||||||
self.act = (
|
self.act = (
|
||||||
|
@ -278,40 +303,65 @@ class Gemma2MLP(nn.Module):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Fuse gate and up proj
|
# Fuse gate and up proj
|
||||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
config,
|
config,
|
||||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
weights=weights,
|
weights=weights,
|
||||||
dim=0,
|
dim=0,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
self.down_proj = TensorParallelRowLinear.load(
|
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||||
|
gate_up_proj,
|
||||||
|
layer_id,
|
||||||
|
["gate_proj", "up_proj"],
|
||||||
|
sizes=[
|
||||||
|
config.intermediate_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
],
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
down_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.down_proj",
|
prefix=f"{prefix}.down_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
down_proj,
|
||||||
|
layer_id,
|
||||||
|
"down_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
self.intermediate_size = (
|
self.intermediate_size = (
|
||||||
config.intermediate_size // weights.process_group.size()
|
config.intermediate_size // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states, adapter_data):
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlashGemma2Layer(nn.Module):
|
class FlashGemma2Layer(nn.Module):
|
||||||
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
|
def __init__(
|
||||||
|
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = FlashGemma2Attention(
|
self.self_attn = FlashGemma2Attention(
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
layer_id=layer_id,
|
||||||
causal=causal,
|
causal=causal,
|
||||||
is_sliding=is_sliding,
|
is_sliding=is_sliding,
|
||||||
)
|
)
|
||||||
self.mlp = Gemma2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
self.mlp = Gemma2MLP(
|
||||||
|
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
|
||||||
|
)
|
||||||
|
|
||||||
self.input_layernorm = Gemma2FastRMSNorm.load(
|
self.input_layernorm = Gemma2FastRMSNorm.load(
|
||||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
@ -344,6 +394,7 @@ class FlashGemma2Layer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
@ -358,6 +409,7 @@ class FlashGemma2Layer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
|
@ -366,7 +418,7 @@ class FlashGemma2Layer(nn.Module):
|
||||||
res = normed_attn_res_output
|
res = normed_attn_res_output
|
||||||
|
|
||||||
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
|
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
|
||||||
mlp_output = self.mlp(pre_normed)
|
mlp_output = self.mlp(pre_normed, adapter_data)
|
||||||
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
|
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
|
||||||
|
|
||||||
return post_hidden_states, normed_attn_res_output
|
return post_hidden_states, normed_attn_res_output
|
||||||
|
@ -385,6 +437,7 @@ class FlashGemma2Model(torch.nn.Module):
|
||||||
prefix=f"{prefix}.layers.{layer_id}",
|
prefix=f"{prefix}.layers.{layer_id}",
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
layer_id=layer_id,
|
||||||
causal=causal,
|
causal=causal,
|
||||||
is_sliding=layer_id % 2 == 0,
|
is_sliding=layer_id % 2 == 0,
|
||||||
)
|
)
|
||||||
|
@ -409,6 +462,7 @@ class FlashGemma2Model(torch.nn.Module):
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
@ -431,6 +485,7 @@ class FlashGemma2Model(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
@ -492,6 +547,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
|
|
@ -19,37 +19,30 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||||
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers.activations import ACT2FN
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
|
||||||
|
|
||||||
from text_generation_server.layers.attention import (
|
|
||||||
paged_attention,
|
|
||||||
attention,
|
|
||||||
reshape_and_cache,
|
|
||||||
Seqlen,
|
|
||||||
)
|
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
FastLinear,
|
FastLinear,
|
||||||
TensorParallelRowLinear,
|
SpeculativeHead,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
SpeculativeHead,
|
TensorParallelRowLinear,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.moe import SparseMoELayer
|
from text_generation_server.layers.attention import (
|
||||||
from text_generation_server.layers.layernorm import (
|
Seqlen,
|
||||||
FastRMSNorm,
|
attention,
|
||||||
)
|
paged_attention,
|
||||||
from text_generation_server.layers.rotary import (
|
reshape_and_cache,
|
||||||
PositionRotaryEmbedding,
|
|
||||||
)
|
)
|
||||||
|
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||||
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.utils.weights import UnquantizedWeight
|
from text_generation_server.utils.weights import UnquantizedWeight
|
||||||
|
|
||||||
|
|
||||||
|
@ -315,14 +308,16 @@ def round_up(x: torch.Tensor, value: int):
|
||||||
return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
|
return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
|
||||||
|
|
||||||
|
|
||||||
class BlockSparseMoE(nn.Module):
|
class MixtralMoE(nn.Module):
|
||||||
def __init__(self, prefix, config: MixtralConfig, weights):
|
def __init__(
|
||||||
|
self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# gating
|
# gating
|
||||||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||||
|
|
||||||
self.moe = SparseMoELayer(
|
self.moe = moe_layer_cls(
|
||||||
n_expert_group=None,
|
n_expert_group=None,
|
||||||
n_experts=config.num_local_experts,
|
n_experts=config.num_local_experts,
|
||||||
prefix=f"{prefix}.experts",
|
prefix=f"{prefix}.experts",
|
||||||
|
@ -334,6 +329,7 @@ class BlockSparseMoE(nn.Module):
|
||||||
up_proj_name="w3",
|
up_proj_name="w3",
|
||||||
down_proj_name="w2",
|
down_proj_name="w2",
|
||||||
)
|
)
|
||||||
|
assert isinstance(self.moe, MoELayer)
|
||||||
|
|
||||||
self.process_group = weights.process_group
|
self.process_group = weights.process_group
|
||||||
|
|
||||||
|
@ -349,95 +345,6 @@ class BlockSparseMoE(nn.Module):
|
||||||
return out.view(*x.shape)
|
return out.view(*x.shape)
|
||||||
|
|
||||||
|
|
||||||
class DenseMoE(nn.Module):
|
|
||||||
def __init__(self, prefix, config: MixtralConfig, weights):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_dim = config.hidden_size
|
|
||||||
self.ffn_dim = config.intermediate_size // weights.process_group.size()
|
|
||||||
self.num_experts = config.num_local_experts
|
|
||||||
self.top_k = config.num_experts_per_tok
|
|
||||||
|
|
||||||
act = config.hidden_act
|
|
||||||
if "gelu" in act:
|
|
||||||
self.act = lambda x: torch.nn.functional.gelu(
|
|
||||||
x,
|
|
||||||
approximate=(
|
|
||||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
elif "silu" in act:
|
|
||||||
self.act = torch.nn.functional.silu
|
|
||||||
else:
|
|
||||||
self.act = ACT2FN[act]
|
|
||||||
|
|
||||||
# gating
|
|
||||||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
|
||||||
|
|
||||||
self.w1 = [
|
|
||||||
TensorParallelColumnLinear.load(
|
|
||||||
config, prefix=f"{prefix}.experts.{i}.w1", weights=weights, bias=False
|
|
||||||
)
|
|
||||||
for i in range(self.num_experts)
|
|
||||||
]
|
|
||||||
self.w3 = [
|
|
||||||
TensorParallelColumnLinear.load(
|
|
||||||
config, prefix=f"{prefix}.experts.{i}.w3", weights=weights, bias=False
|
|
||||||
)
|
|
||||||
for i in range(self.num_experts)
|
|
||||||
]
|
|
||||||
self.w2 = [
|
|
||||||
TensorParallelRowLinear.load(
|
|
||||||
config, prefix=f"{prefix}.experts.{i}.w2", weights=weights, bias=False
|
|
||||||
)
|
|
||||||
for i in range(self.num_experts)
|
|
||||||
]
|
|
||||||
|
|
||||||
self.process_group = weights.process_group
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
x: (sequence_length, model_dim)
|
|
||||||
gate_logits: (sequence_length, n_experts)
|
|
||||||
"""
|
|
||||||
# optional reshape
|
|
||||||
input_shape = x.shape
|
|
||||||
x = x.view(-1, input_shape[-1])
|
|
||||||
|
|
||||||
# gate_logits: (sequence_length, n_experts)
|
|
||||||
gate_logits = self.gate(x)
|
|
||||||
# all_probs: (sequence_length, n_experts) and upcast for softmax
|
|
||||||
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
|
|
||||||
|
|
||||||
if self.top_k < self.num_experts:
|
|
||||||
_, not_selected_experts = torch.topk(
|
|
||||||
all_probs,
|
|
||||||
self.num_experts - self.top_k,
|
|
||||||
largest=False,
|
|
||||||
sorted=False,
|
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
# Mask not selected experts
|
|
||||||
all_probs.scatter_(1, not_selected_experts, 0)
|
|
||||||
|
|
||||||
# Re-normalize
|
|
||||||
weights = all_probs / all_probs.sum(dim=1, keepdim=True)
|
|
||||||
weights = weights.to(x.dtype)
|
|
||||||
|
|
||||||
# Final output tensor
|
|
||||||
out = x.new_zeros(x.shape[0], self.hidden_dim)
|
|
||||||
for i in range(self.num_experts):
|
|
||||||
h = self.act(self.w1[i](x)) * self.w3[i](x)
|
|
||||||
h = self.w2[i](h, reduce=False)
|
|
||||||
# Add expert output to out with masking
|
|
||||||
out += h * weights[:, i].view(-1, 1)
|
|
||||||
|
|
||||||
# Reduce sum
|
|
||||||
if self.process_group.size() > 1:
|
|
||||||
torch.distributed.all_reduce(out, group=self.process_group)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class MixtralLayer(nn.Module):
|
class MixtralLayer(nn.Module):
|
||||||
def __init__(self, prefix: str, layer_id, config, weights):
|
def __init__(self, prefix: str, layer_id, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -447,8 +354,12 @@ class MixtralLayer(nn.Module):
|
||||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||||
)
|
)
|
||||||
|
|
||||||
moe_cls = BlockSparseMoE if SparseMoELayer.is_supported(weights) else DenseMoE
|
moe_layer_cls = (
|
||||||
self.moe = moe_cls(f"{prefix}.block_sparse_moe", config, weights)
|
SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer
|
||||||
|
)
|
||||||
|
self.moe = MixtralMoE(
|
||||||
|
f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
|
||||||
|
)
|
||||||
|
|
||||||
self.input_layernorm = FastRMSNorm.load(
|
self.input_layernorm = FastRMSNorm.load(
|
||||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
|
|
@ -75,7 +75,6 @@ def load_and_merge_adapters(
|
||||||
weight_names: Tuple[str],
|
weight_names: Tuple[str],
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
||||||
|
|
||||||
if len(adapter_parameters.adapter_info) == 1:
|
if len(adapter_parameters.adapter_info) == 1:
|
||||||
adapter = next(iter(adapter_parameters.adapter_info))
|
adapter = next(iter(adapter_parameters.adapter_info))
|
||||||
return load_module_map(
|
return load_module_map(
|
||||||
|
@ -191,16 +190,15 @@ def load_module_map(
|
||||||
weight_names: Tuple[str],
|
weight_names: Tuple[str],
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
|
||||||
|
|
||||||
adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
|
adapter_config = LoraConfig.load(adapter_path or adapter_id, None)
|
||||||
|
|
||||||
if not adapter_path and adapter_config.base_model_name_or_path != model_id:
|
if not adapter_path and adapter_config.base_model_name_or_path != model_id:
|
||||||
check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
|
check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)
|
||||||
|
|
||||||
adapter_filenames = (
|
adapter_filenames = (
|
||||||
hub._adapter_weight_files_from_dir(adapter_path, extension=".safetensors")
|
hub._weight_files_from_dir(adapter_path, extension=".safetensors")
|
||||||
if adapter_path
|
if adapter_path
|
||||||
else hub._cached_adapter_weight_files(
|
else hub._cached_weight_files(
|
||||||
adapter_id, revision=revision, extension=".safetensors"
|
adapter_id, revision=revision, extension=".safetensors"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,17 +18,6 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
||||||
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
|
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
|
||||||
|
|
||||||
|
|
||||||
def _cached_adapter_weight_files(
|
|
||||||
adapter_id: str, revision: Optional[str], extension: str
|
|
||||||
) -> List[str]:
|
|
||||||
"""Guess weight files from the cached revision snapshot directory"""
|
|
||||||
d = _get_cached_revision_directory(adapter_id, revision)
|
|
||||||
if not d:
|
|
||||||
return []
|
|
||||||
filenames = _adapter_weight_files_from_dir(d, extension)
|
|
||||||
return filenames
|
|
||||||
|
|
||||||
|
|
||||||
def _cached_weight_files(
|
def _cached_weight_files(
|
||||||
model_id: str, revision: Optional[str], extension: str
|
model_id: str, revision: Optional[str], extension: str
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
@ -65,39 +54,11 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
||||||
if f.endswith(extension)
|
if f.endswith(extension)
|
||||||
and "arguments" not in f
|
and "arguments" not in f
|
||||||
and "args" not in f
|
and "args" not in f
|
||||||
and "adapter" not in f
|
|
||||||
and "training" not in f
|
and "training" not in f
|
||||||
]
|
]
|
||||||
return filenames
|
return filenames
|
||||||
|
|
||||||
|
|
||||||
def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
|
||||||
# os.walk: do not iterate, just scan for depth 1, not recursively
|
|
||||||
# see _weight_files_from_dir, that's also what is done there
|
|
||||||
root, _, files = next(os.walk(str(d)))
|
|
||||||
filenames = [
|
|
||||||
os.path.join(root, f)
|
|
||||||
for f in files
|
|
||||||
if f.endswith(extension)
|
|
||||||
and "arguments" not in f
|
|
||||||
and "args" not in f
|
|
||||||
and "training" not in f
|
|
||||||
]
|
|
||||||
return filenames
|
|
||||||
|
|
||||||
|
|
||||||
def _adapter_config_files_from_dir(d: Path) -> List[str]:
|
|
||||||
# os.walk: do not iterate, just scan for depth 1, not recursively
|
|
||||||
# see _weight_files_from_dir, that's also what is done there
|
|
||||||
root, _, files = next(os.walk(str(d)))
|
|
||||||
filenames = [
|
|
||||||
os.path.join(root, f)
|
|
||||||
for f in files
|
|
||||||
if f.endswith(".json") and "arguments" not in f and "args" not in f
|
|
||||||
]
|
|
||||||
return filenames
|
|
||||||
|
|
||||||
|
|
||||||
def _get_cached_revision_directory(
|
def _get_cached_revision_directory(
|
||||||
model_id: str, revision: Optional[str]
|
model_id: str, revision: Optional[str]
|
||||||
) -> Optional[Path]:
|
) -> Optional[Path]:
|
||||||
|
|
Loading…
Reference in New Issue