diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 89d5bdf5..d415f369 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -32,10 +32,6 @@ jobs: permissions: contents: write packages: write - # This is used to complete the identity challenge - # with sigstore/fulcio when running outside of PRs. - id-token: write - security-events: write steps: - name: Checkout repository uses: actions/checkout@v4 @@ -50,6 +46,7 @@ jobs: export label_extension="" export docker_devices="" export runs_on="aws-g6-12xlarge-plus-priv" + export platform="" ;; rocm) export dockerfile="Dockerfile_amd" @@ -58,12 +55,21 @@ jobs: # TODO Re-enable when they pass. # export runs_on="amd-gpu-tgi" export runs_on="ubuntu-latest" + export platform="" ;; - intel) + intel-xpu) export dockerfile="Dockerfile_intel" - export label_extension="-intel" + export label_extension="-intel-xpu" export docker_devices="" export runs_on="ubuntu-latest" + export platform="xpu" + ;; + intel-cpu) + export dockerfile="Dockerfile_intel" + export label_extension="-intel-cpu" + export docker_devices="" + export runs_on="ubuntu-latest" + export platform="cpu" ;; esac echo $dockerfile @@ -71,8 +77,10 @@ jobs: echo $label_extension echo $docker_devices echo $runs_on + echo $platform echo "DOCKERFILE=${dockerfile}" >> $GITHUB_ENV echo "LABEL=${label_extension}" >> $GITHUB_ENV + echo "PLATFORM=${platform}" >> $GITHUB_ENV echo "DOCKER_DEVICES=${docker_devices}" >> $GITHUB_ENV echo "RUNS_ON=${runs_on}" >> $GITHUB_ENV echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV @@ -139,6 +147,7 @@ jobs: build-args: | GIT_SHA=${{ env.GITHUB_SHA }} DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} + PLATFORM=${{ env.PLATFORM }} tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min @@ -159,7 +168,7 @@ jobs: group: ${{ needs.build-and-push.outputs.runs_on }} if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' env: - PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }} + PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }} steps: - name: Checkout repository uses: actions/checkout@v4 diff --git a/.github/workflows/build_pr_documentation.yaml b/.github/workflows/build_pr_documentation.yaml index bf03bfdf..a5ce39a5 100644 --- a/.github/workflows/build_pr_documentation.yaml +++ b/.github/workflows/build_pr_documentation.yaml @@ -11,7 +11,7 @@ concurrency: jobs: build: - uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yaml@main + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main with: commit_sha: ${{ github.event.pull_request.head.sha }} pr_number: ${{ github.event.number }} diff --git a/.github/workflows/ci_build.yaml b/.github/workflows/ci_build.yaml index 5ca2854a..5190f321 100644 --- a/.github/workflows/ci_build.yaml +++ b/.github/workflows/ci_build.yaml @@ -37,8 +37,11 @@ jobs: # fail-fast is true by default fail-fast: false matrix: - hardware: ["cuda", "rocm", "intel"] + hardware: 
["cuda", "rocm", "intel-xpu", "intel-cpu"] uses: ./.github/workflows/build.yaml # calls the one above ^ + permissions: + contents: write + packages: write with: hardware: ${{ matrix.hardware }} # https://github.com/actions/runner/issues/2206 diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index f983b6ed..6faabe3b 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -35,7 +35,7 @@ jobs: with: # Released on: 02 May, 2024 # https://releases.rs/docs/1.78.0/ - toolchain: 1.79.0 + toolchain: 1.80.0 override: true components: rustfmt, clippy - name: Install Protoc diff --git a/.gitignore b/.gitignore index 0de8b848..edcc2f89 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,7 @@ backends/client/src/v3/pb # ROCm auto-generated files *.hip -server/exllamav2_kernels/exllamav2_kernels/hip/ +server/exllamav2 server/exllama_kernels/exllama_kernels/hip/ server/exllama_kernels/exllama_kernels/hip_func/ *_hip.cuh @@ -18,3 +18,7 @@ server/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp data/ load_tests/*.json +server/fbgemmm + +.direnv/ +.venv/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6f5e685e..0c8b6885 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - exclude: docs/source/basic_tutorials/launcher.md + exclude: docs/source/reference/launcher.md - repo: https://github.com/psf/black rev: 24.2.0 hooks: diff --git a/.redocly.lint-ignore.yaml b/.redocly.lint-ignore.yaml index 382c9ab6..13b80497 100644 --- a/.redocly.lint-ignore.yaml +++ b/.redocly.lint-ignore.yaml @@ -77,3 +77,4 @@ docs/openapi.json: - '#/paths/~1tokenize/post' - '#/paths/~1v1~1chat~1completions/post' - '#/paths/~1v1~1completions/post' + - '#/paths/~1v1~1models/get' diff --git a/Cargo.lock b/Cargo.lock index 92367d1e..00c7f005 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "ahash" version = "0.8.11" @@ -28,7 +34,7 @@ dependencies = [ "once_cell", "serde", "version_check", - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -121,14 +127,14 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] name = "arrayvec" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "async-rustls" @@ -160,7 +166,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -171,7 +177,7 @@ checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -180,6 +186,17 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = 
"atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -246,9 +263,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e89b6941c2d1a7045538884d6e760ccfffdf8e1ffc2613d8efa74305e1f3752" +checksum = "0f0e249228c6ad2d240c2dc94b714d711629d52bad946075d8e9b2f5391f0703" dependencies = [ "bindgen", "cc", @@ -391,7 +408,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.4", "object", "rustc-demangle", ] @@ -433,7 +450,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.72", + "syn 2.0.76", "which", ] @@ -472,9 +489,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bitstream-io" -version = "2.5.0" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dcde5f311c85b8ca30c2e4198d4326bc342c76541590106f5fa4a50946ea499" +checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452" [[package]] name = "block-buffer" @@ -505,9 +522,9 @@ checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" [[package]] name = "bytemuck" -version = "1.16.1" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" +checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" [[package]] name = "byteorder" @@ -523,15 +540,15 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.6.1" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" +checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" [[package]] name = "camino" -version = "1.1.7" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0ec6b951b160caa93cc0c7b209e5a3bff7aae9062213451ac99493cd844c239" +checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3" dependencies = [ "serde", ] @@ -566,13 +583,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" [[package]] -name = "cc" -version = "1.1.7" +name = "cast" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cc" +version = "1.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6" dependencies = [ "jobserver", "libc", + "shlex", ] [[package]] @@ -606,6 +630,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + 
[[package]] name = "clang-sys" version = "1.8.1" @@ -619,9 +649,20 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.11" +version = "2.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35723e6a11662c2afb578bcf0b88bf6ea8e21282a953428f240574fcc3a2b5b3" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "bitflags 1.3.2", + "textwrap", + "unicode-width", +] + +[[package]] +name = "clap" +version = "4.5.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed6719fffa43d0d87e5fd8caeab59be1554fb028cd30edc88fc4369b17971019" dependencies = [ "clap_builder", "clap_derive", @@ -629,9 +670,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.11" +version = "4.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49eb96cbfa7cfa35017b7cd548c75b14c3118c98b423041d70562665e07fb0fa" +checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6" dependencies = [ "anstream", "anstyle", @@ -641,14 +682,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.11" +version = "4.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d029b67f89d30bbb547c89fd5161293c0aec155fc691d7924b64550662db93e" +checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -659,9 +700,9 @@ checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "cmake" -version = "0.1.50" +version = "0.1.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" +checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" dependencies = [ "cc", ] @@ -713,15 +754,15 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "51e852e6dc9a5bed1fae92dd2375037bf2b768725bf3be87811edee3249d09ad" dependencies = [ "libc", ] @@ -735,6 +776,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f" +dependencies = [ + "atty", + "cast", + "clap 2.34.0", + "criterion-plot", + "csv", + "itertools 0.10.5", + "lazy_static", + "num-traits", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_cbor", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -833,19 +910,19 @@ dependencies = [ [[package]] name = "ctrlc" -version = "3.4.4" +version = "3.4.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" +checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3" dependencies = [ - "nix", - "windows-sys 0.52.0", + "nix 0.29.0", + "windows-sys 0.59.0", ] [[package]] name = "cxx" -version = "1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "273dcfd3acd4e1e276af13ed2a43eea7001318823e7a726a6b3ed39b4acc0b82" +checksum = "3c4eae4b7fc8dcb0032eb3b1beee46b38d371cdeaf2d0c64b9944f6f69ad7755" dependencies = [ "cc", "cxxbridge-flags", @@ -855,9 +932,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b2766fbd92be34e9ed143898fce6c572dc009de39506ed6903e5a05b68914e" +checksum = "6c822bf7fb755d97328d6c337120b6f843678178751cba33c9da25cf522272e0" dependencies = [ "cc", "codespan-reporting", @@ -865,24 +942,24 @@ dependencies = [ "proc-macro2", "quote", "scratch", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] name = "cxxbridge-flags" -version = "1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "839fcd5e43464614ffaa989eaf1c139ef1f0c51672a1ed08023307fa1b909ccd" +checksum = "719d6197dc016c88744aff3c0d0340a01ecce12e8939fc282e7c8f583ee64bc6" [[package]] name = "cxxbridge-macro" -version = "1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b2c1c1776b986979be68bb2285da855f8d8a35851a769fca8740df7c3d07877" +checksum = "35de3b547387863c8f82013c4f79f1c2162edee956383e4089e1d04c18c4f16c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -906,7 +983,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -917,7 +994,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -947,7 +1024,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -957,7 +1034,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" dependencies = [ "derive_builder_core", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -993,9 +1070,9 @@ dependencies = [ [[package]] name = "dunce" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "easy-cast" @@ -1060,9 +1137,9 @@ checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" dependencies = [ "bit_field", "flume", - "half", + "half 2.4.1", "lebe", - "miniz_oxide", + "miniz_oxide 0.7.4", "rayon-core", "smallvec", "zune-inflate", @@ -1080,9 +1157,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "fdeflate" @@ -1101,12 +1178,12 @@ checksum = 
"0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.0.30" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" dependencies = [ "crc32fast", - "miniz_oxide", + "miniz_oxide 0.8.0", ] [[package]] @@ -1232,7 +1309,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1341,7 +1418,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.2.6", + "indexmap 2.4.0", "slab", "tokio", "tokio-util", @@ -1350,9 +1427,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa82e28a107a8cc405f0839610bdc9b15f1e25ec7d696aa5cf173edbcb1486ab" +checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" dependencies = [ "atomic-waker", "bytes", @@ -1360,13 +1437,19 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap 2.2.6", + "indexmap 2.4.0", "slab", "tokio", "tokio-util", "tracing", ] +[[package]] +name = "half" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" + [[package]] name = "half" version = "2.4.1" @@ -1404,6 +1487,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hermit-abi" version = "0.3.9" @@ -1552,7 +1644,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.5", + "h2 0.4.6", "http 1.1.0", "http-body 1.0.1", "httparse", @@ -1610,9 +1702,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" dependencies = [ "bytes", "futures-channel", @@ -1695,9 +1787,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c" dependencies = [ "equivalent", "hashbrown 0.14.5", @@ -1753,7 +1845,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1804,6 +1896,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1827,9 +1928,9 @@ checksum = 
"f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" dependencies = [ "wasm-bindgen", ] @@ -1844,7 +1945,7 @@ dependencies = [ "anyhow", "base64 0.21.7", "bytecount", - "clap", + "clap 4.5.16", "fancy-regex", "fraction", "getrandom", @@ -1884,9 +1985,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" [[package]] name = "libfuzzer-sys" @@ -2038,7 +2139,7 @@ dependencies = [ "hyper 1.4.1", "hyper-rustls", "hyper-util", - "indexmap 2.2.6", + "indexmap 2.4.0", "ipnet", "metrics", "metrics-util", @@ -2081,18 +2182,19 @@ dependencies = [ [[package]] name = "minijinja" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45f7e8e35b6c7b169bf40b0176d2c79291ab8ee53290b84e0668ab21d841aa9d" +checksum = "6d7d3e3a3eece1fa4618237ad41e1de855ced47eab705cec1c9a920e1d1c5aad" dependencies = [ "serde", + "serde_json", ] [[package]] name = "minijinja-contrib" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6853ef2340c668281c5ea86b04da2ebb2fc9e98a7185a887591de4cac945d5b5" +checksum = "744a2b84dbd22398e347594ed2aef9d3f1b948934e3e6e94ef69ecd39d597f4b" dependencies = [ "minijinja", "serde", @@ -2114,6 +2216,15 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "mio" version = "0.8.11" @@ -2128,11 +2239,11 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", "wasi", "windows-sys 0.52.0", @@ -2162,7 +2273,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2252,7 +2363,19 @@ checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" dependencies = [ "bitflags 2.6.0", "cfg-if", - "cfg_aliases", + "cfg_aliases 0.1.1", + "libc", +] + +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "cfg_aliases 0.2.1", "libc", ] @@ -2350,7 +2473,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2400,7 +2523,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.9", "libc", ] @@ -2421,9 +2544,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "object" -version = "0.36.2" +version = "0.36.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" +checksum = "27b64972346851a39438c60b341ebc01bba47464ae329e55cf343eb93964efd9" dependencies = [ "memchr", ] @@ -2456,6 +2579,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "openssl" version = "0.10.66" @@ -2479,7 +2608,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2518,7 +2647,7 @@ checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a" dependencies = [ "futures-core", "futures-sink", - "indexmap 2.2.6", + "indexmap 2.4.0", "js-sys", "once_cell", "pin-project-lite", @@ -2742,7 +2871,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.2.6", + "indexmap 2.4.0", ] [[package]] @@ -2762,7 +2891,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2783,6 +2912,34 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "plotters" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" + +[[package]] +name = "plotters-svg" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.17.13" @@ -2793,7 +2950,7 @@ dependencies = [ "crc32fast", "fdeflate", "flate2", - "miniz_oxide", + "miniz_oxide 0.7.4", ] [[package]] @@ -2810,21 +2967,21 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.18" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee4364d9f3b902ef14fab8a1ddffb783a1cb6b4bba3bfc1fa3922732c7de97f" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy 0.6.6", + "zerocopy", ] [[package]] name = "prettyplease" -version = "0.2.20" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +checksum = 
"479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" dependencies = [ "proc-macro2", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2876,7 +3033,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" dependencies = [ "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2916,7 +3073,7 @@ dependencies = [ "prost 0.12.6", "prost-types", "regex", - "syn 2.0.72", + "syn 2.0.76", "tempfile", ] @@ -2943,7 +3100,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2987,9 +3144,9 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -3078,9 +3235,9 @@ dependencies = [ [[package]] name = "ravif" -version = "0.11.9" +version = "0.11.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5797d09f9bd33604689e87e8380df4951d4912f01b63f71205e2abd4ae25e6b6" +checksum = "a8f0bfd976333248de2078d350bfdf182ff96e168a24d23d2436cef320dd4bdd" dependencies = [ "avif-serialize", "imgref", @@ -3141,9 +3298,9 @@ dependencies = [ [[package]] name = "redox_users" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", @@ -3152,9 +3309,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", @@ -3236,9 +3393,9 @@ dependencies = [ [[package]] name = "rgb" -version = "0.8.45" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade4539f42266ded9e755c605bdddf546242b2c961b03b06a7375260788a0523" +checksum = "0f86ae463694029097b846d8f99fd5536740602ae00022c0c50c5600720b2f71" dependencies = [ "bytemuck", ] @@ -3293,7 +3450,7 @@ dependencies = [ "proc-macro2", "quote", "rust-embed-utils", - "syn 2.0.72", + "syn 2.0.76", "walkdir", ] @@ -3384,12 +3541,12 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" +checksum = "04182dffc9091a404e0fc069ea5cd60e5b866c3adf881eff99a32d048242dffa" dependencies = [ "openssl-probe", - "rustls-pemfile 2.1.2", + "rustls-pemfile 2.1.3", "rustls-pki-types", "schannel", "security-framework", @@ -3406,9 +3563,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.2" +version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" dependencies = [ "base64 0.22.1", "rustls-pki-types", @@ -3416,15 
+3573,15 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" [[package]] name = "rustls-webpki" -version = "0.102.6" +version = "0.102.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e" +checksum = "84678086bd54edf2b415183ed7a94d0efb049f1b646a33e22a36f3794be6ae56" dependencies = [ "aws-lc-rs", "ring 0.17.8", @@ -3518,29 +3675,39 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.204" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" dependencies = [ "serde_derive", ] [[package]] -name = "serde_derive" -version = "1.0.204" +name = "serde_cbor" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" +dependencies = [ + "half 1.8.3", + "serde", +] + +[[package]] +name = "serde_derive" +version = "1.0.209" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] name = "serde_json" -version = "1.0.121" +version = "1.0.127" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ab380d7d9f22ef3f21ad3e6c1ebe8e4fc7a2000ccba2e4d71fc96f15b2cb609" +checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" dependencies = [ "itoa", "memchr", @@ -3742,7 +3909,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -3764,9 +3931,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.72" +version = "2.0.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" +checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525" dependencies = [ "proc-macro2", "quote", @@ -3860,20 +4027,21 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.15" +version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.10.1" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" dependencies = [ "cfg-if", "fastrand", + "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3891,11 +4059,12 @@ version = "2.2.1-dev0" dependencies = [ "async-stream", "async-trait", - "clap", + "clap 4.5.16", "cmake", "cxx", "cxx-build", "log", + "parking_lot", "pkg-config", "text-generation-router", "thiserror", @@ -3912,7 +4081,7 @@ name = 
"text-generation-benchmark" version = "2.2.1-dev0" dependencies = [ "average", - "clap", + "clap 4.5.16", "crossterm", "float-ord", "hf-hub", @@ -3950,11 +4119,11 @@ dependencies = [ name = "text-generation-launcher" version = "2.2.1-dev0" dependencies = [ - "clap", + "clap 4.5.16", "ctrlc", "float_eq", "hf-hub", - "nix", + "nix 0.28.0", "once_cell", "reqwest", "serde", @@ -3974,7 +4143,7 @@ dependencies = [ "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap", + "clap 4.5.16", "csv", "futures", "futures-util", @@ -4022,13 +4191,15 @@ dependencies = [ "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap", + "clap 4.5.16", + "criterion", "futures", "futures-util", "grpc-metadata", "hf-hub", "image", "init-tracing-opentelemetry", + "itertools 0.13.0", "jsonschema", "metrics", "metrics-exporter-prometheus", @@ -4045,6 +4216,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "slotmap", "text-generation-router", "thiserror", "tokenizers", @@ -4061,6 +4233,15 @@ dependencies = [ "utoipa-swagger-ui", ] +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -4078,7 +4259,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4135,6 +4316,16 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.8.0" @@ -4185,14 +4376,14 @@ dependencies = [ [[package]] name = "tokio" -version = "1.39.2" +version = "1.39.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5" dependencies = [ "backtrace", "bytes", "libc", - "mio 1.0.1", + "mio 1.0.2", "parking_lot", "pin-project-lite", "signal-hook-registry", @@ -4219,7 +4410,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4281,9 +4472,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.16" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81967dd0dd2c1ab0bc3468bd7caecc32b8a4aa47d0c8c695d8c2b2108168d62c" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" dependencies = [ "serde", "serde_spanned", @@ -4293,20 +4484,20 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.17" +version = "0.22.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d9f8729f5aea9562aac1cc0441f5d6de3cff1ee0c5d67293eeca5eb36ee7c16" +checksum = 
"583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.4.0", "serde", "serde_spanned", "toml_datetime", @@ -4378,7 +4569,7 @@ dependencies = [ "proc-macro2", "prost-build", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4419,15 +4610,15 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -4449,7 +4640,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4709,7 +4900,7 @@ version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.4.0", "serde", "serde_json", "utoipa-gen", @@ -4725,7 +4916,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4763,7 +4954,7 @@ checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4844,34 +5035,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.42" +version = "0.4.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" dependencies = [ "cfg-if", "js-sys", @@ -4881,9 +5073,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4891,28 +5083,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" dependencies = [ "js-sys", "wasm-bindgen", @@ -4993,11 +5185,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -5052,6 +5244,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.42.2" @@ -5232,9 +5433,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.16" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b480ae9340fc261e6be3e95a1ba86d54ae3f9171132a73ce8d4bbaf68339507c" +checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f" dependencies = [ "memchr", ] @@ -5249,34 +5450,14 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "zerocopy" -version = "0.6.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "854e949ac82d619ee9a14c66a1b674ac730422372ccb759ce0c39cabcf2bf8e6" -dependencies = [ - "byteorder", - "zerocopy-derive 0.6.6", -] - [[package]] name = "zerocopy" version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ - "zerocopy-derive 0.7.35", -] - -[[package]] -name = "zerocopy-derive" -version = "0.6.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "125139de3f6b9d625c39e2efdd73d41bdac468ccd556556440e322be0e1bbd91" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.72", + "byteorder", + "zerocopy-derive", ] [[package]] @@ -5287,7 +5468,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -5307,7 +5488,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 8bf75b90..79fda15d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ 
-29,6 +29,8 @@ tokenizers = { version = "0.19.1", features = ["http"] } hf-hub = { version = "0.3.1", features = ["tokio"] } metrics = { version = "0.23.0" } metrics-exporter-prometheus = { version = "0.15.1", features = [] } +minijinja = { version = "2.2.0", features = ["json"] } +minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } [profile.release] incremental = true diff --git a/Dockerfile b/Dockerfile index 0d57e38d..0d0e89b1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -40,14 +40,14 @@ RUN cargo build --profile release-opt # Python builder # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile -FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install # NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099 ARG PYTORCH_VERSION=2.4.0 ARG PYTHON_VERSION=3.10 # Keep in sync with `server/pyproject.toml -ARG CUDA_VERSION=12.1 +ARG CUDA_VERSION=12.4 ARG MAMBA_VERSION=24.3.0-0 ARG CUDA_CHANNEL=nvidia ARG INSTALL_CHANNEL=pytorch @@ -88,6 +88,7 @@ RUN case ${TARGETPLATFORM} in \ FROM pytorch-install AS kernel-builder ARG MAX_JOBS=8 +ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0+PTX" RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ ninja-build cmake \ @@ -118,29 +119,29 @@ FROM kernel-builder AS exllama-kernels-builder WORKDIR /usr/src COPY server/exllama_kernels/ . -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build +RUN python setup.py build # Build Transformers exllama kernels FROM kernel-builder AS exllamav2-kernels-builder WORKDIR /usr/src -COPY server/exllamav2_kernels/ . 
+COPY server/Makefile-exllamav2 Makefile # Build specific version of transformers -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build +RUN make build-exllamav2 # Build Transformers awq kernels FROM kernel-builder AS awq-kernels-builder WORKDIR /usr/src COPY server/Makefile-awq Makefile # Build specific version of transformers -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq +RUN make build-awq # Build eetq kernels FROM kernel-builder AS eetq-kernels-builder WORKDIR /usr/src COPY server/Makefile-eetq Makefile # Build specific version of transformers -RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq +RUN make build-eetq # Build Lorax Punica kernels FROM kernel-builder AS lorax-punica-builder @@ -183,6 +184,12 @@ WORKDIR /usr/src COPY server/Makefile-selective-scan Makefile RUN make build-all +# Build flashinfer +FROM kernel-builder AS flashinfer-builder +WORKDIR /usr/src +COPY server/Makefile-flashinfer Makefile +RUN make install-flashinfer + # Text Generation Inference base image FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base @@ -191,7 +198,7 @@ ENV PATH=/opt/conda/bin:$PATH \ CONDA_PREFIX=/opt/conda # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 @@ -221,11 +228,13 @@ COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 / # Copy build artifacts from exllama kernels builder COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from exllamav2 kernels builder -COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +COPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from awq kernels builder COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from eetq kernels builder COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from lorax punica kernels builder +COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages # Copy build artifacts from fbgemm builder COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages # Copy build artifacts from vllm builder @@ -233,6 +242,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c # Copy build artifacts from mamba builder COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages +COPY --from=flashinfer-builder /opt/conda/lib/python3.10/site-packages/flashinfer/ /opt/conda/lib/python3.10/site-packages/flashinfer/ # Install flash-attention dependencies RUN pip install einops --no-cache-dir @@ -248,6 +258,9 @@ RUN cd server && \ pip install nvidia-nccl-cu12==2.22.3 ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2 +# This is needed because exl2 tries to load flash-attn +# And fails with our builds.
+ENV EXLLAMA_NO_FLASH_ATTN=1 # Deps before the binaries # The binaries change on every build given we burn the SHA into them diff --git a/Dockerfile_amd b/Dockerfile_amd index d8f16e7e..1940b985 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -199,7 +199,7 @@ RUN python setup.py build FROM base AS base-copy # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 diff --git a/Dockerfile_intel b/Dockerfile_intel index d20f0a01..0cda4d4b 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,6 +1,6 @@ ARG PLATFORM=xpu -FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -57,7 +57,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils # Text Generation Inference base env -ENV HUGGINGFACE_HUB_CACHE=/data \ +ENV HF_HOME=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ PORT=80 @@ -106,7 +106,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins g++ \ git \ wget \ - cmake + cmake \ + libnuma-dev ENV HUGGINGFACE_HUB_CACHE=/data \ HF_HUB_ENABLE_HF_TRANSFER=1 \ @@ -135,7 +136,7 @@ RUN conda install -c conda-forge gperftools mkl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl -RUN pip install triton +RUN pip install triton numa WORKDIR /usr/src @@ -147,16 +148,11 @@ RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install . 
-ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so +ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib -ENV KMP_BLOCKTIME=1 -ENV KMP_TPAUSE=0 -ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist -ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist -ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist # Install server COPY proto proto @@ -175,5 +171,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher FROM ${PLATFORM} AS final +ENV ATTENTION=paged +ENV USE_PREFIX_CACHING=0 +ENV CUDA_GRAPHS=0 ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] diff --git a/README.md b/README.md index a88e0437..cc9d523f 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Swagger API documentation -A Rust, Python and gRPC server for text generation inference. Used in production at [HuggingFace](https://huggingface.co) +A Rust, Python and gRPC server for text generation inference. Used in production at [Hugging Face](https://huggingface.co) to power Hugging Chat, the Inference API and Inference Endpoint. @@ -42,12 +42,15 @@ Text Generation Inference (TGI) is a toolkit for deploying and serving Large Lan - Tensor Parallelism for faster inference on multiple GPUs - Token streaming using Server-Sent Events (SSE) - Continuous batching of incoming requests for increased total throughput +- [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) compatible with Open AI Chat Completion API - Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures - Quantization with : - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) - [GPT-Q](https://arxiv.org/abs/2210.17323) - [EETQ](https://github.com/NetEase-FuXi/EETQ) - [AWQ](https://github.com/casper-hansen/AutoAWQ) + - [Marlin](https://github.com/IST-DASLab/marlin) + - [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) - [Safetensors](https://github.com/huggingface/safetensors) weight loading - Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) - Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor)) @@ -92,6 +95,29 @@ curl 127.0.0.1:8080/generate_stream \ -H 'Content-Type: application/json' ``` +You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses. + +```bash +curl localhost:3000/v1/chat/completions \ + -X POST \ + -d '{ + "model": "tgi", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." 
+ },
+ {
+ "role": "user",
+ "content": "What is deep learning?"
+ }
+ ],
+ "stream": true,
+ "max_tokens": 20
+}' \
+ -H 'Content-Type: application/json'
+```
+
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.2.0-rocm --model-id $model` instead of the command above.
@@ -120,7 +146,7 @@ For example, if you want to serve the gated Llama V2 model variants: or with Docker: ```shell
-model=meta-llama/Llama-2-7b-chat-hf
+model=meta-llama/Meta-Llama-3.1-8B-Instruct
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run token=
@@ -163,6 +189,8 @@ overridden with the `--otlp-service-name` argument ![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png)
+A detailed blog post by Adyen on TGI's inner workings: [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)
+
### Local install You can also opt to install `text-generation-inference` locally.
@@ -232,7 +260,7 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 ### Quantization
-You can also quantize the weights with bitsandbytes to reduce the VRAM requirement:
+You can also run pre-quantized weights (AWQ, GPTQ, Marlin) or quantize weights on the fly with bitsandbytes, EETQ, or fp8 to reduce the VRAM requirement:
```shell text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize
@@ -240,6 +268,8 @@ text-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantiz 4bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.
+Read more about quantization in the [Quantization documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/quantization).
+ ## Develop ```shell diff --git a/_server.nix b/_server.nix new file mode 100644 index 00000000..2cb2f887 --- /dev/null +++ b/_server.nix @@ -0,0 +1,17 @@ +{ + mkPoetryApplication, + pkg-config, + protobuf, + openssl, +}: + +mkPoetryApplication { + # name = "text-generation-server"; + + projectDir = ./server; + + # nativeBuildInputs = [ pkg-config ]; + + # buildInputs = [ openssl.dev protobuf ]; + +} diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index a996b14f..479d31bf 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -153,9 +153,12 @@ impl Client { }), // We truncate the input on the server side to be sure that it has the correct size truncate, + // Most request will have that + add_special_tokens: true, // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + prefix_len: 0, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index ae8a899b..645c076a 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -221,6 +221,7 @@ impl Health for ShardedClient { chunks: vec![Chunk::Text("liveness".into()).into()], }), truncate: 10, + add_special_tokens: true, prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, @@ -244,6 +245,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + prefix_len: 0, adapter_id: None, }; let batch = Batch { diff --git a/backends/trtllm/Cargo.toml b/backends/trtllm/Cargo.toml index 7079d3d1..43a114ba 100644 --- a/backends/trtllm/Cargo.toml +++ b/backends/trtllm/Cargo.toml @@ -8,17 +8,18 @@ homepage.workspace = true [dependencies] async-trait = "0.1" async-stream = "0.3" +clap = { version = "4.5", features = ["derive"] } cxx = "1.0" +log = { version = "0.4", features = [] } text-generation-router = { path = "../../router" } tokenizers = { version = "0.19", features = ["hf-hub"] } tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio-stream = "0.1.15" -clap = { version = "4.5", features = ["derive"] } thiserror = "1.0.62" tracing = "0.1" tracing-opentelemetry = "0.24" tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } -log = { version = "0.4", features = [] } +parking_lot = "0.12" [build-dependencies] cmake = "0.1" diff --git a/backends/trtllm/Dockerfile b/backends/trtllm/Dockerfile index 60ad03f7..5fd2f89f 100644 --- a/backends/trtllm/Dockerfile +++ b/backends/trtllm/Dockerfile @@ -3,7 +3,7 @@ ARG OMPI_VERSION="4.1.6" # Build dependencies resolver stage FROM lukemathwalker/cargo-chef:latest AS chef -WORKDIR /usr/src/text-generation-inference +WORKDIR /usr/src/text-generation-inference/backends/trtllm FROM chef AS planner COPY . . 
@@ -42,7 +42,7 @@ RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILE mkdir /usr/src/mpi && \ tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \ cd /usr/src/mpi && \ - ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \ + ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda && \ make -j all && \ make install && \ rm -rf "/opt/src/$OMPI_TARBALL_FILENAME" @@ -66,7 +66,7 @@ ENV PATH="/root/.cargo/bin:$PATH" RUN cargo install cargo-chef # Cache dependencies -COPY --from=planner /usr/src/text-generation-inference/recipe.json . +COPY --from=planner /usr/src/text-generation-inference/backends/trtllm/recipe.json . RUN cargo chef cook --release --recipe-path recipe.json # Build actual TGI @@ -79,7 +79,8 @@ COPY . . COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \ - CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm + cd backends/trtllm && \ + CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime WORKDIR /usr/local/tgi/bin diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs index b26d06a6..b23aa6c0 100644 --- a/backends/trtllm/src/backend.rs +++ b/backends/trtllm/src/backend.rs @@ -12,12 +12,13 @@ use cxx::UniquePtr; use log::{error, warn}; use tokenizers::Tokenizer; use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; -use tokio::sync::RwLock; use tokio::time::{sleep, Instant}; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::{Stream, StreamExt}; use tracing::{instrument, span, Level}; +// use tokio::sync::RwLock; +use parking_lot::RwLock; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidationError::UnsupportedModality; use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError}; diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index 6d6ee146..e0ba46c7 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -1,12 +1,10 @@ +use clap::Parser; use std::collections::HashMap; use std::path::PathBuf; - -use clap::Parser; -use tokenizers::{FromPretrainedParameters, Tokenizer}; - use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; use text_generation_backends_trtllm::TensorRtLlmBackend; use text_generation_router::server; +use tokenizers::{FromPretrainedParameters, Tokenizer}; /// App Configuration #[derive(Parser, Debug)] @@ -160,6 +158,8 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { messages_api_enabled, true, max_client_batch_size, + false, + false, ) .await?; Ok(()) diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml index 5d9a140b..69dad072 100644 --- a/backends/v3/Cargo.toml +++ b/backends/v3/Cargo.toml @@ -33,9 +33,16 @@ rand = "0.8.5" reqwest = { version = "0.11.20", features = [] } serde = "1.0.188" serde_json = "1.0.107" +slotmap = "1.0.7" thiserror = "1.0.48" -tokenizers = { workspace = true} -tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } +tokenizers = { workspace = true } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "parking_lot", + "signal", + "sync", +] } 
tokio-stream = "0.1.14" tower-http = { version = "0.5.1", features = ["cors"] } tracing = "0.1.37" @@ -43,9 +50,11 @@ tracing-opentelemetry = "0.21.0" tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] } utoipa = { version = "4.2.0", features = ["axum_extras"] } utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] } -init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] } -minijinja = { version = "2.0.2" } -minijinja-contrib = { version = "2.0.2", features = ["pycompat"] } +init-tracing-opentelemetry = { version = "0.14.1", features = [ + "opentelemetry-otlp", +] } +minijinja = { workspace = true } +minijinja-contrib = { workspace = true } futures-util = "0.3.30" regex = "1.10.3" once_cell = "1.19.0" @@ -59,8 +68,16 @@ tower = "^0.4" tonic-build = "0.10.1" prost-build = "0.12.1" +[dev-dependencies] +criterion = "0.3" +itertools = "0.13" + [features] default = ["ngrok"] ngrok = ["text-generation-router/ngrok"] google = ["text-generation-router/google"] kserve = ["text-generation-router/kserve"] + +[[bench]] +name = "prefix_cache" +harness = false diff --git a/backends/v3/benches/prefix_cache.rs b/backends/v3/benches/prefix_cache.rs new file mode 100644 index 00000000..d9df45b2 --- /dev/null +++ b/backends/v3/benches/prefix_cache.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::Rng; + +use text_generation_router_v3::block_allocator::Allocator; +use text_generation_router_v3::radix::RadixAllocator; + +fn prefix_cache_benchmark(c: &mut Criterion) { + // let prefixes: Vec> = (0..8192) + // .chunks(256) + // .into_iter() + // .map(|c| c.collect()) + // .collect(); + + let mut cache = RadixAllocator::new(1, 262144, None); + + c.bench_function("Radix allocator", |b| { + b.iter_batched( + || { + //prefixes + // .choose_multiple(&mut rand::thread_rng(), 5) + // .fold(Vec::new(), |mut v, s| { + // v.extend(s); + // v + // }) + + (0..7936) + .map(|_| rand::thread_rng().gen_range(0..1024)) + .collect::>() + }, + |prefill| { + let alloc = cache.allocate( + prefill.len() as u32 + 13, + Some(Arc::new(black_box(prefill))), + ); + if let Some(alloc) = alloc { + cache.free(alloc.blocks.clone(), alloc.allocation_id); + } + }, + criterion::BatchSize::SmallInput, + ); + }); +} + +criterion_group!(benches, prefix_cache_benchmark); +criterion_main!(benches); diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index d82355de..935f7980 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -6,7 +6,7 @@ use nohash_hasher::IntMap; use std::sync::Arc; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; use text_generation_router::validation::ValidGenerateRequest; -use text_generation_router::{FinishReason, PrefillToken, Token}; +use text_generation_router::{Attention, FinishReason, PrefillToken, Token}; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Notify}; use tokio::time::Instant; @@ -35,16 +35,20 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { - let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { - matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") - } else { - false - }; - let block_size = if flashdecoding { 256 } else { 16 }; + let prefix_caching = + std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var"); + let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1"); + let attention: 
String = std::env::var("ATTENTION").expect("attention env var"); + + let attention: Attention = attention + .parse() + .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")); + let block_size = attention.block_size(); let queue = Queue::new( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, @@ -164,11 +168,14 @@ pub(crate) async fn batching_task( None } else { // Minimum batch size + // TODO: temporarily disable to avoid incorrect deallocation + + // reallocation when using prefix caching. Some((batch_size as f32 * waiting_served_ratio).floor() as usize) }; let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize); + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 7467fd85..4fea172b 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,21 +1,31 @@ -use std::cmp::min; +use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; +use crate::radix::RadixAllocator; + #[derive(Debug, Clone)] -pub(crate) struct BlockAllocation { +pub struct BlockAllocation { + pub allocation_id: u64, pub blocks: Vec, pub slots: Vec, - block_allocator: BlockAllocator, + + /// Prefix that was cached and for which the KV does not have to + /// be recomputed. + pub prefix_len: u32, + + pub(crate) block_allocator: Option, } impl Drop for BlockAllocation { fn drop(&mut self) { - self.block_allocator.free(self.blocks.clone()) + if let Some(block_allocator) = self.block_allocator.as_mut() { + block_allocator.free(self.blocks.clone(), self.allocation_id) + } } } #[derive(Debug, Clone)] -pub(crate) struct BlockAllocator { +pub struct BlockAllocator { /// Channel to communicate with the background task block_allocator: mpsc::UnboundedSender, } @@ -24,6 +34,7 @@ impl BlockAllocator { pub(crate) fn new( max_batch_total_tokens: u32, block_size: u32, + prefix_caching: bool, window_size: Option, ) -> Self { // Create channel @@ -33,6 +44,7 @@ impl BlockAllocator { tokio::spawn(block_allocator_task( max_batch_total_tokens / block_size, block_size, + prefix_caching, window_size, receiver, )); @@ -42,28 +54,32 @@ impl BlockAllocator { } } - pub(crate) async fn allocate(&self, tokens: u32) -> Option { + pub(crate) async fn allocate( + &self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option { let (response_sender, response_receiver) = oneshot::channel(); self.block_allocator .send(BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, }) .unwrap(); - response_receiver - .await - .unwrap() - .map(|(blocks, slots)| BlockAllocation { - blocks, - slots, - block_allocator: self.clone(), - }) + response_receiver.await.unwrap().map(|mut allocation| { + allocation.block_allocator = Some(self.clone()); + allocation + }) } - pub(crate) fn free(&self, blocks: Vec) { + pub(crate) fn free(&self, blocks: Vec, allocation_id: u64) { self.block_allocator - .send(BlockAllocatorCommand::Free { blocks }) + .send(BlockAllocatorCommand::Free { + allocation_id, + blocks, + }) .unwrap(); } } @@ -71,54 +87,29 @@ impl BlockAllocator { async fn block_allocator_task( blocks: u32, block_size: u32, + prefix_caching: bool, window_size: Option, mut receiver: mpsc::UnboundedReceiver, ) { - // Block 0 is reserved for health 
checks - let mut free_blocks: Vec = (1..blocks).collect(); + let mut allocator: Box = if prefix_caching { + Box::new(RadixAllocator::new(block_size, blocks, window_size)) + } else { + Box::new(SimpleAllocator::new(blocks, block_size, window_size)) + }; while let Some(cmd) = receiver.recv().await { match cmd { - BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks), + BlockAllocatorCommand::Free { + blocks, + allocation_id, + } => allocator.free(blocks, allocation_id), BlockAllocatorCommand::Allocate { tokens, + prefill_tokens, response_sender, } => { - // Apply window size - let (required_blocks, repeats) = { - let (tokens, repeats) = match window_size { - None => (tokens, 1), - Some(window_size) => { - let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); - (tokens, repeats as usize) - } - }; - // Pad to a multiple of block size - let required_blocks = (tokens + block_size - 1) / block_size; - (required_blocks, repeats) - }; - - let tokens = tokens as usize; - let allocation = if required_blocks > free_blocks.len() as u32 { - None - } else { - let blocks = - free_blocks.split_off(free_blocks.len() - required_blocks as usize); - let mut slots = Vec::with_capacity( - (required_blocks * block_size * repeats as u32) as usize, - ); - - 'slots: for block_id in blocks.repeat(repeats).iter() { - for s in (block_id * block_size)..((block_id + 1) * block_size) { - slots.push(s); - if slots.len() == tokens { - break 'slots; - } - } - } - Some((blocks, slots)) - }; - response_sender.send(allocation).unwrap(); + response_sender + .send(allocator.allocate(tokens, prefill_tokens)) + .unwrap(); } } } @@ -128,9 +119,91 @@ async fn block_allocator_task( enum BlockAllocatorCommand { Free { blocks: Vec, + allocation_id: u64, }, Allocate { tokens: u32, - response_sender: oneshot::Sender, Vec)>>, + prefill_tokens: Option>>, + response_sender: oneshot::Sender>, }, } + +pub trait Allocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option; + + fn free(&mut self, blocks: Vec, allocation_id: u64); +} +pub struct SimpleAllocator { + free_blocks: Vec, + block_size: u32, + window_size: Option, +} + +impl SimpleAllocator { + fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { + SimpleAllocator { + block_size, + // Block 0 is reserved for health checks + free_blocks: (1..blocks).collect(), + window_size, + } + } +} + +impl Allocator for SimpleAllocator { + fn allocate( + &mut self, + tokens: u32, + _prefill_tokens: Option>>, + ) -> Option { + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match self.window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = core::cmp::min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + self.block_size - 1) / self.block_size; + (required_blocks, repeats) + }; + + let tokens = tokens as usize; + if required_blocks > self.free_blocks.len() as u32 { + None + } else { + let blocks = self + .free_blocks + .split_off(self.free_blocks.len() - required_blocks as usize); + let mut slots = + Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); + + 'slots: for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() == tokens { + break 'slots; + } + } + } + 
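+ // Worked example for the loop above (illustrative values): blocks = [5, 6] with block_size = 4 and tokens = 6 yields slots = [20, 21, 22, 23, 24, 25].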
Some(BlockAllocation { + allocation_id: 0, + blocks, + slots, + prefix_len: 0, + block_allocator: None, + }) + } + } + + fn free(&mut self, blocks: Vec, _allocation_id: u64) { + self.free_blocks.extend(blocks) + } +} diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index c407687b..648662db 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -149,6 +149,7 @@ impl Client { requests.push(Request { id: 0, inputs, + add_special_tokens: true, input_chunks: Some(Input { chunks: input_chunks, }), @@ -157,6 +158,7 @@ impl Client { // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], + prefix_len: 0, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index afb13cdc..ea77a696 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -222,6 +222,7 @@ impl Health for ShardedClient { chunks: vec![Chunk::Text("liveness".into()).into()], }), truncate: 10, + add_special_tokens: true, prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, @@ -245,6 +246,7 @@ impl Health for ShardedClient { // Block 0 is reserved for health checks blocks: vec![0], slots: (0..16).collect(), + prefix_len: 0, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index a6f89169..77a9a11a 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -1,7 +1,8 @@ mod backend; -mod block_allocator; +pub mod block_allocator; mod client; mod queue; +pub mod radix; use crate::client::{ClientError, ShardedClient}; pub(crate) use backend::BackendV3; diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 21952e66..471ddb5a 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -150,6 +150,14 @@ async fn main() -> Result<(), RouterError> { } } + if let Some(max_batch_size) = max_batch_size { + if max_batch_size == 0 { + return Err(RouterError::ArgumentValidation( + "`max_batch_size` must be > 0".to_string(), + )); + } + } + let (backend, _backend_info) = connect_backend( max_input_tokens, max_total_tokens, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index 9427bd60..978a495c 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -46,6 +46,7 @@ impl Queue { pub(crate) fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, @@ -57,6 +58,7 @@ impl Queue { tokio::spawn(queue_task( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, @@ -109,6 +111,7 @@ impl Queue { async fn queue_task( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, @@ -117,6 +120,7 @@ async fn queue_task( let mut state = State::new( requires_padding, block_size, + prefix_caching, window_size, speculate, max_batch_total_tokens, @@ -176,12 +180,19 @@ impl State { fn new( requires_padding: bool, block_size: u32, + prefix_caching: bool, window_size: Option, speculate: u32, max_batch_total_tokens: u32, ) -> Self { - let block_allocator = (!requires_padding) - .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size)); + let 
block_allocator = (!requires_padding).then(|| {
+ BlockAllocator::new(
+ max_batch_total_tokens,
+ block_size,
+ prefix_caching,
+ window_size,
+ )
+ });
Self { entries: VecDeque::with_capacity(128),
@@ -226,25 +237,29 @@ impl State { } }
+ if let Some(max_size) = max_size {
+ if max_size == 0 {
+ tracing::debug!("No capacity");
+ return None;
+ }
+ }
+
// Pad prefill_token_budget to be a multiple of block size let prefill_token_budget = ((prefill_token_budget + self.block_size - 1) / self.block_size) * self.block_size; // Create span for this batch to add context to inference calls let next_batch_span = info_span!(parent: None, "batch", batch_size = tracing::field::Empty);
- next_batch_span.follows_from(&Span::current());
-
- let mut batch_requests = Vec::with_capacity(self.entries.len());
- let mut batch_entries =
- IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
+ next_batch_span.follows_from(Span::current());
+ let mut batch = Vec::with_capacity(self.entries.len());
let mut max_input_length = 0; let mut prefill_tokens: u32 = 0; let mut decode_tokens: u32 = 0; let mut max_blocks = 0; // Pop entries starting from the front of the queue
- 'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
+ 'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request // was dropped by the client) if entry.response_tx.is_closed() {
@@ -258,7 +273,7 @@ impl State { // We pad to max input length in the Python shards // We need to take these padding tokens into the equation max_input_length = max_input_length.max(entry.request.input_length);
- prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
+ prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
decode_tokens += entry.request.stopping_parameters.max_new_tokens; let total_tokens = prefill_tokens + decode_tokens + self.speculate;
@@ -272,7 +287,7 @@ impl State { } None }
- Some(block_allocator) => {
+ Some(_block_allocator) => {
prefill_tokens += entry.request.input_length; let max_new_tokens = match self.window_size { None => entry.request.stopping_parameters.max_new_tokens,
@@ -298,23 +313,67 @@ impl State { + self.speculate - 1;
- match block_allocator.allocate(tokens).await {
- None => {
- // Entry is over budget
- // Add it back to the front
- tracing::debug!("Over budget: not enough free blocks");
- self.entries.push_front((id, entry));
- break 'entry_loop;
- }
- Some(block_allocation) => {
- tracing::debug!("Allocation: {block_allocation:?}");
- max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
- Some(block_allocation)
- }
- }
+ // If the user wants the prefill logprobs, we cannot reuse the cache.
+ // So no input_ids for the radix tree.
+ let input_ids = if entry.request.decoder_input_details {
+ None
+ } else {
+ entry.request.input_ids.clone()
+ };
+
+ Some((tokens, input_ids))
} };
+ batch.push((id, entry, block_allocation));
+ if Some(batch.len()) == max_size {
+ break;
+ }
+ }
+ // Empty batch
+ if batch.is_empty() {
+ tracing::debug!("Filtered out all entries");
+ return None;
+ }
+
+ // XXX We haven't allocated yet, so we're allowed to ditch the results.
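+ // Batch construction is now two-phase: entries were only collected above; blocks are allocated further down, once we know the batch will actually be used.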
+ // Check if our batch is big enough + if let Some(min_size) = min_size { + // Batch is too small + if batch.len() < min_size { + // Add back entries to the queue in the correct order + for (id, entry, _) in batch.into_iter().rev() { + self.entries.push_front((id, entry)); + } + return None; + } + } + + let mut batch_requests = Vec::with_capacity(self.entries.len()); + let mut batch_entries = + IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); + + for (id, mut entry, block_allocation) in batch { + let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) = + (block_allocation, &self.block_allocator) + { + match block_allocator.allocate(tokens, input_ids).await { + None => { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: not enough free blocks"); + self.entries.push_front((id, entry)); + break; + } + Some(block_allocation) => { + tracing::debug!("Allocation: {block_allocation:?}"); + max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + Some(block_allocation) + } + } + } else { + None + }; tracing::debug!("Accepting entry"); // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); @@ -324,11 +383,12 @@ impl State { // Update entry entry.temp_span = Some(entry_batch_span); - let (blocks, slots) = match &block_allocation { - None => (Vec::new(), Vec::new()), + let (blocks, slots, prefix_len) = match &block_allocation { + None => (Vec::new(), Vec::new(), 0), Some(block_allocation) => ( block_allocation.blocks.clone(), block_allocation.slots.clone(), + block_allocation.prefix_len, ), }; @@ -356,6 +416,7 @@ impl State { }), inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, + add_special_tokens: entry.request.add_special_tokens, parameters: Some(NextTokenChooserParameters::from( entry.request.parameters.clone(), )), @@ -365,38 +426,13 @@ impl State { top_n_tokens: entry.request.top_n_tokens, blocks, slots, + prefix_len, adapter_id: entry.request.adapter_id.clone(), }); // Set batch_time entry.batch_time = Some(Instant::now()); // Insert in batch_entries IntMap batch_entries.insert(id, entry); - - // Check if max_size - if Some(batch_requests.len()) == max_size { - break; - } - } - - // Empty batch - if batch_requests.is_empty() { - tracing::debug!("Filterered out all entries"); - return None; - } - - // Check if our batch is big enough - if let Some(min_size) = min_size { - // Batch is too small - if batch_requests.len() < min_size { - // Add back entries to the queue in the correct order - for r in batch_requests.into_iter().rev() { - let id = r.id; - let entry = batch_entries.remove(&id).unwrap(); - self.entries.push_front((id, entry)); - } - - return None; - } } // Final batch size @@ -473,6 +509,8 @@ impl From for StoppingCriteriaParameters { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use tracing::info_span; @@ -485,7 +523,9 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: vec![], + input_ids: Some(Arc::new(vec![])), input_length: 0, + add_special_tokens: true, truncate: 0, decoder_input_details: false, parameters: ValidParameters { @@ -520,7 +560,7 @@ mod tests { #[tokio::test] async fn test_append() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -536,7 +576,7 @@ mod tests { #[tokio::test] async fn 
test_next_batch_empty() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); assert!(state.next_batch(None, None, 1, 1).await.is_none()); assert!(state.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -544,7 +584,7 @@ mod tests { #[tokio::test] async fn test_next_batch_min_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -576,7 +616,7 @@ mod tests { #[tokio::test] async fn test_next_batch_max_size() { - let mut state = State::new(false, 1, None, 0, 16); + let mut state = State::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -596,7 +636,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0, 2); + let mut state = State::new(false, 1, false, None, 0, 2); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -629,14 +669,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); assert!(queue.next_batch(None, None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none()); @@ -644,7 +684,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -677,7 +717,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_max_size() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -693,7 +733,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -718,7 +758,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2, 16); + let queue = Queue::new(false, 1, false, None, 2, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -737,7 +777,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false, 1, None, 0, 16); + let queue = Queue::new(false, 1, false, None, 0, 16); let (entry, _) = default_entry(); queue.append(entry); diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs new file mode 100644 index 00000000..1f3bef15 --- /dev/null +++ b/backends/v3/src/radix.rs @@ -0,0 +1,850 @@ +use crate::block_allocator::{Allocator, BlockAllocation}; +use slotmap::{DefaultKey, SlotMap}; +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; + +pub struct RadixAllocator { + 
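+ /// Monotonically increasing counter used to key live allocations in `allocations`.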
allocation_id: u64,
+
+ allocations: HashMap<u64, RadixAllocation>,
+
+ cache_blocks: RadixTrie,
+
+ /// Blocks that are immediately available for allocation.
+ free_blocks: Vec<u32>,
+
+ #[allow(dead_code)]
+ // This isn't used because the prefix needs to match without the windowing
+ // mechanism. This at worst is overallocating, not necessarily being wrong.
+ window_size: Option<u32>,
+
+ block_size: u32,
+}
+
+impl RadixAllocator {
+ pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {
+ RadixAllocator {
+ allocation_id: 0,
+ allocations: HashMap::new(),
+ cache_blocks: RadixTrie::new(block_size as usize),
+
+ // Block 0 is reserved for health checks.
+ free_blocks: (1..n_blocks).collect(),
+ window_size,
+ block_size,
+ }
+ }
+
+ fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {
+ if self.free_blocks.len() < n_blocks_needed {
+ // This is a bit annoying, we first extend the free list and then
+ // split it off again below. This is because we need to put it on
+ // the free list if we cannot allocate enough blocks. This is only
+ // temporary, the trie needs to be able to report whether it can
+ // allocate the requested amount. Just not implemented yet.
+ self.free_blocks.extend(
+ self.cache_blocks
+ .evict(n_blocks_needed - self.free_blocks.len()),
+ );
+ }
+
+ if self.free_blocks.len() >= n_blocks_needed {
+ Some(
+ self.free_blocks
+ .split_off(self.free_blocks.len() - n_blocks_needed),
+ )
+ } else {
+ None
+ }
+ }
+}
+
+// Allocator trait
+impl Allocator for RadixAllocator {
+ fn allocate(
+ &mut self,
+ tokens: u32,
+ prefill_tokens: Option<Arc<Vec<u32>>>,
+ ) -> Option<BlockAllocation> {
+ let mut blocks = vec![];
+ let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
+ let node_id = self
+ .cache_blocks
+ .find(prefill_tokens.as_slice(), &mut blocks);
+ node_id
+ } else {
+ self.cache_blocks.root_id()
+ };
+
+ // Even if this allocation fails below, we need to increase the
+ // refcount to ensure that the prefix that was found is not evicted.
+ self.cache_blocks
+ .incref(prefix_node)
+ .expect("Failed to increment refcount");
+
+ let prefix_len = blocks.len() * self.block_size as usize;
+ let suffix_len = tokens - prefix_len as u32;
+
+ let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size;
+
+ tracing::info!("Prefix {prefix_len} - Suffix {suffix_len}");
+
+ match self.alloc_or_reclaim(suffix_blocks as usize) {
+ Some(suffix_blocks) => blocks.extend(suffix_blocks),
+ None => {
+ self.cache_blocks
+ .decref(prefix_node)
+ .expect("Failed to decrement refcount");
+ return None;
+ }
+ }
+
+ // 1:1 mapping of blocks and slots.
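+ // Otherwise each block id b expands to the slot range b * block_size..(b + 1) * block_size, truncated once `tokens` slots have been produced.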
+ let slots = if self.block_size == 1 {
+ blocks.clone()
+ } else {
+ let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);
+ 'slots: for block_id in &blocks {
+ for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
+ slots.push(s);
+ if slots.len() as u32 == tokens {
+ break 'slots;
+ }
+ }
+ }
+ slots
+ };
+
+ let allocation = RadixAllocation {
+ prefix_node,
+ cached_prefix_len: prefix_len,
+ prefill_tokens: prefill_tokens.clone(),
+ };
+
+ tracing::debug!("Blocks {blocks:?}");
+
+ self.allocation_id += 1;
+ self.allocations.insert(self.allocation_id, allocation);
+
+ Some(BlockAllocation {
+ allocation_id: self.allocation_id,
+ block_allocator: None,
+ blocks,
+ slots,
+ prefix_len: prefix_len as u32,
+ })
+ }
+
+ fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
+ let allocation = match self.allocations.remove(&allocation_id) {
+ Some(allocation) => allocation,
+ None => unreachable!("Tried to free an unknown allocation."),
+ };
+
+ self.cache_blocks
+ .decref(allocation.prefix_node)
+ .expect("Failed to decrement refcount");
+
+ if let Some(prefill_tokens) = allocation.prefill_tokens {
+ let prefill_tokens = prefill_tokens.as_slice();
+
+ // If there are prefill tokens that did not come from the cache,
+ // add them to the cache.
+ if prefill_tokens.len() > allocation.cached_prefix_len {
+ let aligned =
+ (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;
+ if aligned > 0 {
+ let prefix_len = self
+ .cache_blocks
+ .insert(
+ &prefill_tokens[..aligned],
+ &blocks[..aligned / self.block_size as usize],
+ )
+ // Unwrap, failing is a programming error.
+ .expect("Failed to store prefill tokens");
+ // We can have a prefill with the following structure:
+ //
+ // |---| From the prefix cache.
+ // A B C D E F G
+ //|--------| Found in the trie during insertion.
+ //
+ // This means that while processing this request there was a
+ // partially overlapping request that had A..=E in its
+ // prefill. In this case we need to free the blocks D E.
+ if prefix_len > allocation.cached_prefix_len {
+ self.free_blocks.extend(
+ &blocks[allocation.cached_prefix_len / self.block_size as usize
+ ..prefix_len / self.block_size as usize],
+ );
+ }
+ }
+ }
+
+ // Free non-prefill blocks.
+ self.free_blocks
+ .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);
+ } else {
+ self.free_blocks.extend(blocks);
+ }
+ }
+}
+
+struct RadixAllocation {
+ prefix_node: NodeId,
+ cached_prefix_len: usize,
+ prefill_tokens: Option<Arc<Vec<u32>>>,
+}
+
+// Radix trie that is heavily inspired by radix attention from sglang.
+//
+// The trie is optimized for prefix caching:
+//
+// - A normal radix trie stores discrete values. In this radix trie,
+// inserting *abc* with value *xyz* will also enable lookup for
+// *a* (*x*) and *ab* (*xy*).
+// - As a result, every value is required to have the same length as
+// the key.
+// - We store additional information in each node, such as last access
+// time and a reference count.
+
+#[derive(Debug)]
+pub enum TrieError {
+ InvalidNodeId,
+ RefCountUnderflow,
+ BlockTokenCountMismatch,
+}
+
+pub type NodeId = DefaultKey;
+
+#[derive(Debug)]
+pub struct RadixTrie {
+ /// Identifier of the root node.
+ root: DefaultKey,
+
+ /// Leaf node identifiers ordered by increasing recency.
+ leaves: BTreeSet<(u64, NodeId)>,
+
+ /// All trie nodes.
+ nodes: SlotMap<NodeId, TrieNode>,
+
+ /// Time as a monotonically increasing counter to avoid the system
+ /// call that a real time lookup would require.
+ time: u64, + + /// All blocks need to be aligned with this + block_size: usize, +} + +impl RadixTrie { + /// Construct a new radix trie. + pub fn new(block_size: usize) -> Self { + let root = TrieNode::new(vec![], vec![], 0, None); + let mut nodes = SlotMap::new(); + let root = nodes.insert(root); + RadixTrie { + leaves: BTreeSet::new(), + nodes, + root, + time: 0, + block_size, + } + } + + /// Find the prefix of the given tokens. + /// + /// The blocks corresponding to the part of the prefix that could be found + /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`. + /// Returns the identifier of the trie node that contains the longest + /// prefix. The node identifier can be used by callers to e.g. increase its + /// reference count. + /// + /// Using this method will update the access time of the traversed nodes. + pub fn find(&mut self, key: &[u32], blocks: &mut Vec) -> NodeId { + self.time += 1; + self.find_(self.root, key, blocks) + } + + /// Find worker. + fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { + let node = &self.nodes[node_id]; + + if let Some(&child_id) = node.children.get(&key[0]) { + self.update_access_time(child_id); + let child = self.nodes.get(child_id).expect("Invalid child identifier"); + let shared_prefix_len = shared_prefix(&child.key, key, self.block_size); + assert_eq!(shared_prefix_len % self.block_size, 0); + blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); + + let key = &key[shared_prefix_len..]; + if !key.is_empty() { + node_id = self.find_(child_id, key, blocks); + } + } + + node_id + } + + /// Decrease the reference count of a node. + pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + // We don't care about refcounting for root, since it will never + // be evicted. + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + return Err(TrieError::RefCountUnderflow); + } + + node.ref_count -= 1; + if node.ref_count == 0 { + assert!( + node.children.is_empty(), + "Nodes with children must have refcount > 0" + ); + + self.leaves.insert((node.last_accessed, node_id)); + } + + Ok(()) + } + + /// Increase the reference count of a node. + pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> { + if node_id == self.root { + return Ok(()); + } + + let node = self + .nodes + .get_mut(node_id) + .ok_or(TrieError::InvalidNodeId)?; + if node.ref_count == 0 { + self.leaves.remove(&(node.last_accessed, node_id)); + } + node.ref_count += 1; + + Ok(()) + } + + /// Evict `n_blocks` from the trie. + /// + /// Returns the evicted blocks. When the length is less than `n_blocks`, + /// not enough blocks could be evicted. + pub fn evict(&mut self, n_blocks: usize) -> Vec { + // NOTE: we don't return Result here. If any of the unwrapping fails, + // it's a programming error in the trie implementation, not a user + // error caused by e.g. an invalid argument. + + // TODO: add some bookkeeping in the future to check whether we can + // evict n_blocks and return `None` if we can't. We are now needlessly + // evicting prefixes from the cache in such a case. 
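+ // `leaves` is ordered by (last_accessed, node_id), so pop_first() below always yields the least recently used leaf: eviction is LRU.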
+ let mut evicted = Vec::new();
+
+ while let Some((last_access, node_id)) = self.leaves.pop_first() {
+ let blocks_needed = n_blocks - evicted.len();
+
+ let node = self.nodes.get(node_id).expect("Leaf does not exist");
+ assert_eq!(
+ node.ref_count, 0,
+ "Leaf must have refcount of 0, got {}",
+ node.ref_count
+ );
+
+ if blocks_needed >= node.blocks.len() {
+ // We need to evict the whole node if we need more blocks than it has.
+ let node = self.remove_node(node_id);
+ evicted.extend(node.blocks);
+
+ if evicted.len() >= n_blocks {
+ break;
+ }
+ } else {
+ // The node has more blocks than needed, so we'll just remove
+ // the required number of blocks and leave the remaining blocks
+ // untouched.
+ let node = self.nodes.get_mut(node_id).expect("Leaf does not exist");
+ node.key.truncate(node.blocks.len() - blocks_needed);
+ evicted.extend(node.blocks.split_off(node.blocks.len() - blocks_needed));
+ self.leaves.insert((last_access, node_id));
+ break;
+ }
+ }
+
+ evicted
+ }
+
+ /// Insert a prefill along with its blocks.
+ ///
+ /// This method returns the length of the prefix that was already
+ /// in the trie. E.g. if the length is 10, this means that for
+ /// the first 10 elements of the trie **the blocks are not updated**.
+ pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {
+ self.time += 1;
+ let common = self.insert_(self.root, tokens, blocks)?;
+ Ok(common)
+ }
+
+ /// Insertion worker.
+ fn insert_(
+ &mut self,
+ node_id: NodeId,
+ tokens: &[u32],
+ blocks: &[u32],
+ ) -> Result<usize, TrieError> {
+ // TODO: in the future we may want to check that the blocks match for
+ // the part of the prefix that is already in the trie to detect
+ // mismatches.
+
+ if tokens.len() != blocks.len() * self.block_size {
+ return Err(TrieError::BlockTokenCountMismatch);
+ }
+
+ if let Some(&child_id) = self.nodes[node_id].children.get(&tokens[0]) {
+ self.update_access_time(child_id);
+ let child = self
+ .nodes
+ .get_mut(child_id)
+ // Unwrap here, since failure is a bug.
+ .expect("Child node does not exist");
+ let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);
+
+ // We are done, the prefix is already in the trie.
+ if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {
+ return Ok(shared_prefix_len);
+ }
+
+ // The node's prefix is a prefix of the insertion prefix.
+ if shared_prefix_len == child.key.len() {
+ return Ok(shared_prefix_len
+ + self.insert_(
+ child_id,
+ &tokens[shared_prefix_len..],
+ &blocks[shared_prefix_len / self.block_size..],
+ )?);
+ }
+
+ // The node's prefix and the insertion prefix only match partially,
+ // split the node to just contain the matching part. Then insert the
+ // remainder of the prefix into the node again.
+ let child_id = self.split_node(child_id, shared_prefix_len);
+ let key = &tokens[shared_prefix_len..];
+ let blocks = &blocks[shared_prefix_len / self.block_size..];
+ Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)
+ } else {
+ self.add_node(node_id, tokens, blocks);
+ Ok(0)
+ }
+ }
+
+ fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {
+ // We have to make the current node a child to ensure that its
+ // properties and node id stay the same.
+
+ // This function unwraps, an invalid node_id is a programming error.
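+ // Illustrative example (block_size = 1): splitting a node whose key is [7, 8, 9, 10] at prefix_len = 2 creates a new parent holding [7, 8], while this node keeps [9, 10].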
+
+ let node = self
+ .nodes
+ .get_mut(node_id)
+ .expect("Node to-be split does not exist");
+ let mut parent_key = node.key.split_off(prefix_len);
+ let mut parent_blocks = node.blocks.split_off(prefix_len);
+
+ // Move first part of the prefix to the parent. We swap to avoid
+ // an allocation + copy for both splits of the key/blocks.
+ std::mem::swap(&mut node.key, &mut parent_key);
+ std::mem::swap(&mut node.blocks, &mut parent_blocks);
+
+ let node_key = node.key[0];
+
+ let grandparent_id = node.parent.expect("Node does not have a parent");
+ let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);
+ self.add_node_to_parent(parent_id, node_key, node_id);
+
+ // Reborrow to make the borrow checker happy.
+ let node = self
+ .nodes
+ .get_mut(node_id)
+ .expect("Node to-be split does not exist");
+ node.parent = Some(parent_id);
+
+ parent_id
+ }
+
+ /// Create a node and add it to the parent.
+ fn add_node(
+ &mut self,
+ parent_id: NodeId,
+ key: impl Into<Vec<u32>>,
+ blocks: impl Into<Vec<u32>>,
+ ) -> NodeId {
+ let key = key.into();
+ let blocks = blocks.into();
+ let first = key[0];
+
+ let child = TrieNode::new(key, blocks, self.time, Some(parent_id));
+ let child_id = self.nodes.insert(child);
+
+ self.add_node_to_parent(parent_id, first, child_id);
+ self.leaves.insert((self.time, child_id));
+
+ child_id
+ }
+
+ /// Add a node to the parent.
+ fn add_node_to_parent(&mut self, parent_id: NodeId, first: u32, child_id: NodeId) {
+ // Unwrap here, passing in an unknown id is a programming error.
+ let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
+ if parent.children.insert(first, child_id).is_none() {
+ // Only increase reference count if child does not replace another child.
+ self.incref(parent_id)
+ .expect("Failed to increase parent refcount");
+ }
+ }
+
+ /// Remove a node from the trie.
+ fn remove_node(&mut self, node_id: NodeId) -> TrieNode {
+ // Unwrap here, passing in an unknown id is a programming error.
+ let node = self.nodes.remove(node_id).expect("Unknown node");
+ assert!(
+ node.children.is_empty(),
+ "Tried to remove a node with {} children",
+ node.children.len()
+ );
+ let parent_id = node.parent.expect("Attempted to remove root node");
+ let parent = self.nodes.get_mut(parent_id).expect("Unknown parent node");
+ parent.children.remove(&node.key[0]);
+ self.decref(parent_id)
+ .expect("Failed to decrease parent refcount");
+ node
+ }
+
+ fn update_access_time(&mut self, node_id: NodeId) {
+ // Unwrap here, passing in an unknown id is a programming error.
+ let node = self.nodes.get_mut(node_id).expect("Unknown node");
+
+ // Update the ordered leaves set if the node is a leaf.
+ if self.leaves.remove(&(node.last_accessed, node_id)) {
+ self.leaves.insert((self.time, node_id));
+ }
+
+ node.last_accessed = self.time;
+ }
+
+ #[allow(dead_code)]
+ #[doc(hidden)]
+ /// Print debugging output for the trie.
+ ///
+ /// In contrast to `Debug`, nicely formatted.
+ pub fn print_debug(&self) { + self.print_debug_(self.root, 0); + } + + fn print_debug_(&self, node_id: NodeId, indent: usize) { + let node = &self.nodes[node_id]; + eprintln!( + "{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}", + " ".repeat(indent), + node_id, + node.key, + node.blocks, + node.ref_count, + node.last_accessed, + node.parent, + node.children + ); + for child_id in self.nodes[node_id].children.values() { + self.print_debug_(*child_id, indent + 2); + } + } + + pub(crate) fn root_id(&self) -> DefaultKey { + self.root + } +} + +/// Trie node. +#[derive(Debug)] +struct TrieNode { + blocks: Vec, + children: HashMap, + key: Vec, + last_accessed: u64, + parent: Option, + ref_count: usize, +} + +impl TrieNode { + fn new(key: Vec, blocks: Vec, last_accessed: u64, parent: Option) -> Self { + TrieNode { + children: HashMap::new(), + key, + blocks, + last_accessed, + parent, + ref_count: 0, + } + } +} + +fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { + let full = left.iter().zip(right).take_while(|(a, b)| a == b).count(); + // NOTE: this is the case because the child node was chosen based on + // matching the first character of the key/prefix. + assert!(full > 0, "Prefixes must at least share 1 token"); + (full / block_size) * block_size +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + + #[test] + fn allocator_block_size() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_block_size_non_aligned() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 2); + } + + #[test] + fn allocator_reuses_prefixes() { + let mut cache = RadixAllocator::new(1, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.blocks, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_collects_older_prefixes_first() { + let mut cache = RadixAllocator::new(1, 7, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation1.blocks, vec![3, 4, 
5, 6]); + assert_eq!(allocation1.prefix_len, 0); + + let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap(); + assert_eq!(allocation2.blocks, vec![1, 2]); + assert_eq!(allocation2.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // We should get the blocks of the first allocation, since they are more recent. + let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap(); + assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]); + assert_eq!(allocation3.prefix_len, 0); + } + + #[test] + fn allocator_frees_fully_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 10, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation3.prefix_len, 4); + + // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks. + assert_eq!(cache.free_blocks.len(), 5); + } + + #[test] + fn allocator_frees_partially_overlapping_prefills() { + let mut cache = RadixAllocator::new(1, 20, None); + let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap(); + assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]); + assert_eq!(allocation1.prefix_len, 0); + + cache.free(allocation1.blocks.clone(), allocation1.allocation_id); + + let allocation2 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]); + assert_eq!(allocation2.prefix_len, 2); + + let allocation3 = cache + .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation3.prefix_len, 2); + + cache.free(allocation3.blocks.clone(), allocation3.allocation_id); + cache.free(allocation2.blocks.clone(), allocation2.allocation_id); + + // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2. + assert_eq!(cache.free_blocks.len(), 11); + + let allocation4 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5]))) + .unwrap(); + assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]); + assert_eq!(allocation4.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + + let allocation5 = cache + .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7]))) + .unwrap(); + assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]); + assert_eq!(allocation5.prefix_len, 6); + assert_eq!(cache.free_blocks.len(), 11); + } + + #[test] + fn trie_insertions_have_correct_prefix_len() { + let mut trie = RadixTrie::new(1); + + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); + + // Already exists. + assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3); + + // Shares partial prefix, we need a split. 
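+ // [0, 1, 2, 3, 4] is already in the trie; the new key shares only [0, 1, 2, 3], so the existing node is split after the common prefix of length 4, which is the returned value.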
+ assert_eq!(
+ trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
+ .unwrap(),
+ 4
+ );
+ }
+
+ #[test]
+ fn trie_insertions_block_size() {
+ let mut trie = RadixTrie::new(2);
+
+ assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);
+
+ // Already exists.
+ // But needs to be block_size-aligned.
+ assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);
+
+ // Completely new at root-level
+ assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);
+
+ // Contains full prefix, but longer.
+ assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);
+
+ // Shares partial prefix, we need a split.
+ assert_eq!(
+ trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])
+ .unwrap(),
+ 2
+ );
+ }
+
+ #[test]
+ fn trie_get_returns_correct_blocks() {
+ let mut trie = RadixTrie::new(1);
+ trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
+ trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
+ trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
+ trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
+ .unwrap();
+
+ let mut blocks = Vec::new();
+ trie.find(&[0], &mut blocks);
+ assert_eq!(blocks, vec![0]);
+
+ blocks.clear();
+ trie.find(&[0, 1, 2], &mut blocks);
+ assert_eq!(blocks, vec![0, 1, 2]);
+
+ blocks.clear();
+ trie.find(&[1, 2, 3], &mut blocks);
+ assert_eq!(blocks, vec![1, 2, 3]);
+
+ blocks.clear();
+ trie.find(&[0, 1, 2, 3], &mut blocks);
+ assert_eq!(blocks, vec![0, 1, 2, 3]);
+
+ blocks.clear();
+ trie.find(&[0, 1, 2, 3, 4], &mut blocks);
+ assert_eq!(blocks, vec![0, 1, 2, 3, 4]);
+
+ blocks.clear();
+ trie.find(&[0, 1, 2, 3, 5], &mut blocks);
+ assert_eq!(blocks, vec![0, 1, 2, 3, 5]);
+ }
+
+ #[test]
+ fn trie_evict_removes_correct_blocks() {
+ let mut trie = RadixTrie::new(1);
+ trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();
+ trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])
+ .unwrap();
+ trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();
+ trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();
+
+ let mut blocks = Vec::new();
+
+ // Remove less than the leaf blocks.
+ assert_eq!(trie.evict(1), vec![7]);
+ trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
+ assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);
+
+ // Refresh other leaf.
+ trie.find(&[0, 1, 2, 3, 4], &mut blocks);
+ trie.find(&[1, 2, 3], &mut blocks);
+
+ // Remove the leaf blocks exactly.
+ assert_eq!(trie.evict(2), vec![5, 6]);
+ blocks.clear();
+ trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);
+ assert_eq!(blocks, vec![0, 1, 2, 3]);
+
+ trie.find(&[1, 2, 3], &mut blocks);
+
+ // Remove more than the leaf blocks.
+ assert_eq!(trie.evict(3), vec![4, 3, 2]);
+ blocks.clear();
+ trie.find(&[0, 1, 2, 3, 4], &mut blocks);
+ assert_eq!(blocks, vec![0, 1]);
+
+ // Clear out the whole trie.
+ assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]); + } +} diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 5e739703..789c7b51 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -148,6 +148,7 @@ async fn prefill( }), inputs: sequence.clone(), truncate: sequence_length, + add_special_tokens: true, parameters: Some(parameters.clone()), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: decode_length, @@ -157,6 +158,7 @@ async fn prefill( top_n_tokens: top_n_tokens.unwrap_or(0), blocks: vec![], slots: vec![], + prefix_len: 0, adapter_id: None, }) .collect(); diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 12966747..45301b63 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -757,7 +757,12 @@ class AsyncClient: continue payload = byte_payload.decode("utf-8") if payload.startswith("data:"): - json_payload = json.loads(payload.lstrip("data:").rstrip("\n")) + payload_data = ( + payload.lstrip("data:").rstrip("\n").removeprefix(" ") + ) + if payload_data == "[DONE]": + break + json_payload = json.loads(payload_data) try: response = ChatCompletionChunk(**json_payload) yield response diff --git a/docs/openapi.json b/docs/openapi.json index ed9b0b96..691705f2 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -556,6 +556,37 @@ } } } + }, + "/v1/models": { + "get": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Get model info", + "operationId": "openai_get_model_info", + "responses": { + "200": { + "description": "Served model info", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ModelInfo" + } + } + } + }, + "404": { + "description": "Model not found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + } + } + } } }, "components": { @@ -819,6 +850,13 @@ "example": "1.0", "nullable": true }, + "guideline": { + "type": "string", + "description": "A guideline to be used in the chat_template", + "default": "null", + "example": "null", + "nullable": true + }, "logit_bias": { "type": "array", "items": { @@ -917,7 +955,7 @@ "tool_prompt": { "type": "string", "description": "A prompt to be appended before the tools", - "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"", + "example": "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. 
Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.", "nullable": true }, "tools": { @@ -1740,6 +1778,35 @@ } ] }, + "ModelInfo": { + "type": "object", + "required": [ + "id", + "object", + "created", + "owned_by" + ], + "properties": { + "created": { + "type": "integer", + "format": "int64", + "example": 1686935002, + "minimum": 0 + }, + "id": { + "type": "string", + "example": "gpt2" + }, + "object": { + "type": "string", + "example": "model" + }, + "owned_by": { + "type": "string", + "example": "openai" + } + } + }, "OutputMessage": { "oneOf": [ { @@ -1817,7 +1884,8 @@ "type": "object", "required": [ "finish_reason", - "generated_tokens" + "generated_tokens", + "input_length" ], "properties": { "finish_reason": { @@ -1829,6 +1897,12 @@ "example": 1, "minimum": 0 }, + "input_length": { + "type": "integer", + "format": "int32", + "example": 1, + "minimum": 0 + }, "seed": { "type": "integer", "format": "int64", diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index e97c00aa..b883b36d 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -17,8 +17,6 @@ title: Installation from source - local: supported_models title: Supported Models and Hardware - - local: messages_api - title: Messages API - local: architecture title: Internal Architecture - local: usage_statistics @@ -33,8 +31,6 @@ title: Serving Private & Gated Models - local: basic_tutorials/using_cli title: Using TGI CLI - - local: basic_tutorials/launcher - title: All TGI CLI options - local: basic_tutorials/non_core_models title: Non-core Model Serving - local: basic_tutorials/safety @@ -48,6 +44,14 @@ - local: basic_tutorials/train_medusa title: Train Medusa title: Tutorials +- sections: + - local: reference/launcher + title: All TGI CLI options + - local: reference/metrics + title: Exported Metrics + - local: reference/api_reference + title: API Reference + title: Reference - sections: - local: conceptual/streaming title: Streaming @@ -64,9 +68,11 @@ - local: conceptual/speculation title: Speculation (Medusa, ngram) - local: conceptual/guidance - title: How Guidance Works (via outlines + title: How Guidance Works (via outlines) - local: conceptual/lora title: LoRA (Low-Rank Adaptation) + - local: conceptual/external + title: External Resources title: Conceptual Guides diff --git a/docs/source/basic_tutorials/consuming_tgi.md b/docs/source/basic_tutorials/consuming_tgi.md index 4829ec7c..b07e7219 100644 --- a/docs/source/basic_tutorials/consuming_tgi.md +++ b/docs/source/basic_tutorials/consuming_tgi.md @@ -1,81 +1,125 @@ # Consuming Text Generation Inference -There are many ways you can consume Text Generation Inference server in your applications. After launching, you can use the `/generate` route and make a `POST` request to get results from the server. You can also use the `/generate_stream` route if you want TGI to return a stream of tokens. You can make the requests using the tool of your preference, such as curl, Python or TypeScrpt. For a final end-to-end experience, we also open-sourced ChatUI, a chat interface for open-source models. +There are many ways to consume Text Generation Inference (TGI) server in your applications. After launching the server, you can use the [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) `/v1/chat/completions` route and make a `POST` request to get results from the server. 
You can also pass `"stream": true` to the call if you want TGI to return a stream of tokens.
+
+For more information on the API, consult the OpenAPI documentation of `text-generation-inference` available [here](https://huggingface.github.io/text-generation-inference).
+
+You can make the requests using any tool of your preference, such as curl, Python, or TypeScript. For an end-to-end experience, we've open-sourced [ChatUI](https://github.com/huggingface/chat-ui), a chat interface for open-access models.

## curl

-After the launch, you can query the model using either the `/generate` or `/generate_stream` routes:
+After a successful server launch, you can query the model using the `/v1/chat/completions` route to get responses that comply with the OpenAI Chat Completions spec:
+
+```bash
+curl localhost:8080/v1/chat/completions \
+    -X POST \
+    -d '{
+  "model": "tgi",
+  "messages": [
+    {
+      "role": "system",
+      "content": "You are a helpful assistant."
+    },
+    {
+      "role": "user",
+      "content": "What is deep learning?"
+    }
+  ],
+  "stream": true,
+  "max_tokens": 20
+}' \
+    -H 'Content-Type: application/json'
+```
+
+For non-chat use cases, you can also use the `/generate` and `/generate_stream` routes (a streaming sketch for the latter appears at the end of this page).

```bash
curl 127.0.0.1:8080/generate \
    -X POST \
-    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
+    -d '{
+  "inputs":"What is Deep Learning?",
+  "parameters":{
+    "max_new_tokens":20
+  }
+}' \
    -H 'Content-Type: application/json'
```

+## Python

-## Inference Client
+### Inference Client

-[`huggingface-hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a nice high-level class, [`~huggingface_hub.InferenceClient`], which makes it easy to make calls to a TGI endpoint. `InferenceClient` also takes care of parameter validation and provides a simple to-use interface.
-You can simply install `huggingface-hub` package with pip.
+[`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. It provides a high-level class, [`huggingface_hub.InferenceClient`](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient), which makes it easy to make calls to TGI's Messages API. `InferenceClient` also takes care of parameter validation and provides a simple-to-use interface.
+
+Install the `huggingface_hub` package via pip.

```bash
-pip install huggingface-hub
+pip install huggingface_hub
```

-Once you start the TGI server, instantiate `InferenceClient()` with the URL to the endpoint serving the model. You can then call `text_generation()` to hit the endpoint through Python.
+You can now use `InferenceClient` the exact same way you would use the `OpenAI` client in Python:

```python
from huggingface_hub import InferenceClient

-client = InferenceClient(model="http://127.0.0.1:8080")
-client.text_generation(prompt="Write a code for snake game")
+client = InferenceClient(
+    base_url="http://localhost:8080/v1/",
+)
+
+output = client.chat.completions.create(
+    model="tgi",
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Count to 10"},
+    ],
+    stream=True,
+    max_tokens=1024,
+)
+
+for chunk in output:
+    print(chunk.choices[0].delta.content)
```

-You can do streaming with `InferenceClient` by passing `stream=True`. Streaming will return tokens as they are being generated in the server. To use streaming, you can do as follows:
+You can check out more details about OpenAI compatibility [here](https://huggingface.co/docs/huggingface_hub/en/guides/inference#openai-compatibility).
+
+There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient). A short streaming sketch with it appears at the end of this page.
+
+### OpenAI Client
+
+You can directly use the OpenAI [Python](https://github.com/openai/openai-python) or [JS](https://github.com/openai/openai-node) clients to interact with TGI.
+
+Install the OpenAI Python package via pip.
+
+```bash
+pip install openai
+```

```python
-for token in client.text_generation("How do you make cheese?", max_new_tokens=12, stream=True):
-    print(token)
+from openai import OpenAI
+
+# init the client but point it to TGI
+client = OpenAI(
+    base_url="http://localhost:8080/v1/",
+    api_key="-"
+)
+
+chat_completion = client.chat.completions.create(
+    model="tgi",
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant." },
+        {"role": "user", "content": "What is deep learning?"}
+    ],
+    stream=True
+)
+
+# iterate and print stream
+for message in chat_completion:
+    print(message)
```

-Another parameter you can use with TGI backend is `details`. You can get more details on generation (tokens, probabilities, etc.) by setting `details` to `True`. When it's specified, TGI will return a `TextGenerationResponse` or `TextGenerationStreamResponse` rather than a string or stream.
+## UI

-```python
-output = client.text_generation(prompt="Meaning of life is", details=True)
-print(output)
-
-# TextGenerationResponse(generated_text=' a complex concept that is not always clear to the individual. It is a concept that is not always', details=Details(finish_reason=, generated_tokens=20, seed=None, prefill=[], tokens=[Token(id=267, text=' a', logprob=-2.0723474, special=False), Token(id=11235, text=' complex', logprob=-3.1272552, special=False), Token(id=17908, text=' concept', logprob=-1.3632495, special=False),..))
-```
-
-You can see how to stream below.
-
-```python
-output = client.text_generation(prompt="Meaning of life is", stream=True, details=True)
-print(next(iter(output)))
-
-# TextGenerationStreamResponse(token=Token(id=267, text=' a', logprob=-2.0723474, special=False), generated_text=None, details=None)
-```
-
-You can check out the details of the function [here](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). There is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
-
-
-## ChatUI
-
-ChatUI is an open-source interface built for LLM serving. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces.
- -To serve both ChatUI and TGI in same environment, simply add your own endpoints to the `MODELS` variable in `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served. - -``` -{ -// rest of the model config here -"endpoints": [{"url": "https://HOST:PORT/generate_stream"}] -} -``` - -![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png) - -## Gradio +### Gradio Gradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. Let's install Gradio and Hub Python library first. @@ -89,19 +133,28 @@ Assume you are serving your model on port 8080, we will query through [Inference import gradio as gr from huggingface_hub import InferenceClient -client = InferenceClient(model="http://127.0.0.1:8080") +client = InferenceClient(base_url="http://127.0.0.1:8080") def inference(message, history): partial_message = "" - for token in client.text_generation(message, max_new_tokens=20, stream=True): - partial_message += token + output = client.chat.completions.create( + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": message}, + ], + stream=True, + max_tokens=1024, + ) + + for chunk in output: + partial_message += chunk.choices[0].delta.content yield partial_message gr.ChatInterface( inference, chatbot=gr.Chatbot(height=300), textbox=gr.Textbox(placeholder="Chat with me!", container=False, scale=7), - description="This is the demo for Gradio UI consuming TGI endpoint with LLaMA 7B-Chat model.", + description="This is the demo for Gradio UI consuming TGI endpoint.", title="Gradio 🤝 TGI", examples=["Are tomatoes vegetables?"], retry_btn="Retry", @@ -110,20 +163,7 @@ gr.ChatInterface( ).queue().launch() ``` -The UI looks like this 👇 - -
- - -
- -You can try the demo directly here 👇 +You can check out the UI and try the demo directly here 👇
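+
+## Async Inference Client
+
+A minimal streaming sketch with the async client (assuming, like the examples above, a TGI server listening on `localhost:8080`; adjust the address to your deployment):
+
+```python
+import asyncio
+
+from huggingface_hub import AsyncInferenceClient
+
+# Assumption: a TGI server is reachable at this address.
+client = AsyncInferenceClient(base_url="http://localhost:8080")
+
+
+async def main():
+    output = await client.chat_completion(
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "What is deep learning?"},
+        ],
+        max_tokens=1024,
+        stream=True,
+    )
+    # Print chunks as the server generates them.
+    async for chunk in output:
+        if chunk.choices:
+            print(chunk.choices[0].delta.content or "", end="")
+
+
+asyncio.run(main())
+```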
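+
+## Streaming over `/generate_stream`
+
+If you consume `/generate_stream` without a client library, the route replies with server-sent events, one `data:` payload per generated token. A minimal sketch with `requests`, under the same `localhost:8080` assumption:
+
+```python
+import json
+
+import requests
+
+# Assumption: a TGI server is reachable at this address.
+with requests.post(
+    "http://localhost:8080/generate_stream",
+    json={"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}},
+    stream=True,
+) as response:
+    for line in response.iter_lines():
+        # Each event looks like `data:{...}`; blank lines separate events.
+        if line and line.startswith(b"data:"):
+            event = json.loads(line.decode("utf-8").removeprefix("data:").strip())
+            print(event["token"]["text"], end="", flush=True)
+```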
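+
+## Checking the served model
+
+TGI also exposes an OpenAI-style `/v1/models` route for inspecting which model a server is running. A quick sketch, again assuming a server on `localhost:8080`:
+
+```python
+import requests
+
+# Assumption: a TGI server is reachable at this address.
+response = requests.get("http://localhost:8080/v1/models")
+response.raise_for_status()
+
+# Per the `ModelInfo` schema, the payload includes `id`, `object`,
+# `created`, and `owned_by`.
+print(response.json())
+```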