Rebase TRT-llm (#2331)

* wip

wip

refacto

refacto

Initial setup for CXX binding to TRTLLM

Working FFI call for TGI and TRTLLM backend

Remove unused parameters annd force tokenizer name to be set

Overall build TRTLLM and deps through CMake build system

Enable end to end CMake build

First version loading engines and making it ready for inference

Remembering to check how we can detect support for chunked context

Move to latest TensorRT-LLM version

Specify which default log level to use depending on CMake build type

make leader executor mode working

unconditionally call InitializeBackend on the FFI layer

bind to CUDA::nvml to retrieve compute capabilities at runtime

updated logic and comment to detect cuda compute capabilities

implement the Stream method to send new tokens through a callback

use spdlog release 1.14.1 moving forward

update trtllm to latest version a96cccafcf6365c128f004f779160951f8c0801c

correctly tell cmake to build dependent tensorrt-llm required libraries

create cmake install target to put everything relevant in installation folder

add auth_token CLI argument to provide hf hub authentification token

allow converting huggingface::tokenizers error to TensorRtLlmBackendError

use correct include for spdlog

include guard to build example in cmakelists

working setup of the ffi layer

remove fmt import

use external fmt lib

end to end ffi flow working

make sure to track include/ffi.h to trigger rebuild from cargo

impl the rust backend which currently cannot move the actual computation in background thread

expose shutdown function at ffi layer

impl RwLock scenario for TensorRtLllmBackend

oops missing c++ backend definitions

compute the number of maximum new tokens for each request independently

make sure the context is not dropped in the middle of the async decoding.

remove unnecessary log

add all the necessary plumbery to return the generated content

update invalid doc in cpp file

correctly forward back the log probabilities

remove unneeded scope variable for now

refactor Stream impl for Generation to factorise code

expose the internal missing start/queue timestamp

forward tgi parameters rep/freq penalty

add some more validation about grammar not supported

define a shared struct to hold the result of a decoding step

expose information about potential error happening while decoding

remove logging

add logging in case of decoding error

make sure executor_worker is provided

add initial Dockerfile for TRTLLM backend

add some more information in CMakeLists.txt to correctly install executorWorker

add some more information in CMakeLists.txt to correctly find and install nvrtc wrapper

simplify prebuilt trtllm libraries name definition

do the same name definition stuff for tensorrt_llm_executor_static

leverage pkg-config to probe libraries paths and reuse new install structure from cmake

fix bad copy/past missing nvinfer linkage direction

align all the linker search dependency

add missing pkgconfig folder for MPI in Dockerfile

correctly setup linking search path for runtime layer

fix missing / before tgi lib path

adding missing ld_library_path for cuda stubs in Dockerfile

update tgi entrypoint

commenting out Python part for TensorRT installation

refactored docker image

move to TensorRT-LLM v0.11.0

make docker linter happy with same capitalization rule

fix typo

refactor the compute capabilities detection along with num gpus

update TensorRT-LLM to latest version

update TensorRT install script to latest

update build.rs to link to cuda 12.5

add missing dependant libraries for linking

clean up a bit

install to decoder_attention target

add some custom stuff for nccl linkage

fix envvar CARGO_CFG_TARGET_ARCH set at runtime vs compile time

use std::env::const::ARCH

make sure variable live long enough...

look for cuda 12.5

add some more basic info in README.md

* Rebase.

* Fix autodocs.

* Let's try to enable trtllm backend.

* Ignore backends/v3 by default.

* Fixing client.

* Fix makefile + autodocs.

* Updating the schema thing + redocly.

* Fix trtllm lint.

* Adding pb files ?

* Remove cargo fmt temporarily.

* ?

* Tmp.

* Remove both check + clippy  ?

* Backporting telemetry.

* Backporting 457fb0a1

* Remove PB from git.

* Fixing PB with default member backends/client

* update TensorRT-LLM to latest version

* provided None for api_key

* link against libtensorrt_llm and not libtensorrt-llm

---------

Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Co-authored-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
Nicolas Patry 2024-07-31 10:33:10 +02:00 committed by GitHub
parent 53aec27328
commit 2b19d671b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
75 changed files with 4867 additions and 1424 deletions

View File

View File

View File

@ -2,3 +2,5 @@ aml
target
server/transformers
server/flash-attention
cmake-build-debug/
cmake-build-release/

View File

@ -28,7 +28,7 @@ jobs:
- name: Install router
id: install-router
run: cargo install --path router/
run: cargo install --path backends/v3/
- uses: actions/setup-node@v4
with:
@ -41,5 +41,5 @@ jobs:
- name: Check that documentation is up-to-date
run: |
npm install -g swagger-cli
npm install -g @redocly/cli
python update_doc.py --check

4
.gitignore vendored
View File

@ -3,6 +3,10 @@ target
router/tokenizer.json
*__pycache__*
backends/v3/src/client/pb
backends/client/src/v2/pb
backends/client/src/v3/pb
# ROCm auto-generated files
*.hip
server/exllamav2_kernels/exllamav2_kernels/hip/

View File

@ -13,8 +13,8 @@ repos:
- repo: https://github.com/doublify/pre-commit-rust
rev: v1.0
hooks:
- id: fmt
- id: cargo-check
- id: fmt
- id: clippy
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0

79
.redocly.lint-ignore.yaml Normal file
View File

@ -0,0 +1,79 @@
# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API.
# See https://redoc.ly/docs/cli/ for more information.
docs/openapi.json:
no-empty-servers:
- '#/openapi'
spec:
- >-
#/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum
- >-
#/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/grammar/nullable'
- >-
#/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum
- '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum'
- >-
#/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum
- '#/components/schemas/GenerateResponse/properties/details/nullable'
- '#/components/schemas/StreamResponse/properties/details/nullable'
- '#/components/schemas/ChatRequest/properties/response_format/nullable'
- '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
- '#/components/schemas/ToolChoice/nullable'
- '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
- '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
no-invalid-media-type-examples:
- '#/paths/~1/post/responses/422/content/application~1json/example'
- '#/paths/~1/post/responses/424/content/application~1json/example'
- '#/paths/~1/post/responses/429/content/application~1json/example'
- '#/paths/~1/post/responses/500/content/application~1json/example'
- '#/paths/~1generate/post/responses/422/content/application~1json/example'
- '#/paths/~1generate/post/responses/424/content/application~1json/example'
- '#/paths/~1generate/post/responses/429/content/application~1json/example'
- '#/paths/~1generate/post/responses/500/content/application~1json/example'
- >-
#/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example
- >-
#/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example
- '#/paths/~1tokenize/post/responses/404/content/application~1json/example'
- >-
#/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example
- >-
#/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/422/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/424/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/429/content/application~1json/example
- >-
#/paths/~1v1~1completions/post/responses/500/content/application~1json/example
operation-4xx-response:
- '#/paths/~1health/get/responses'
- '#/paths/~1info/get/responses'
- '#/paths/~1metrics/get/responses'
no-unused-components:
- '#/components/schemas/Completion'
security-defined:
- '#/paths/~1/post'
- '#/paths/~1generate/post'
- '#/paths/~1generate_stream/post'
- '#/paths/~1health/get'
- '#/paths/~1info/get'
- '#/paths/~1metrics/get'
- '#/paths/~1tokenize/post'
- '#/paths/~1v1~1chat~1completions/post'
- '#/paths/~1v1~1completions/post'

734
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,19 @@
[workspace]
members = [
"benchmark",
"router",
"router/client",
"router/grpc-metadata",
"launcher"
"benchmark",
"backends/v3",
"backends/grpc-metadata",
"backends/trtllm",
"backends/client",
"launcher"
]
default-members = [
"benchmark",
"backends/v3",
"backends/grpc-metadata",
# "backends/trtllm",
"backends/client",
"launcher"
]
resolver = "2"
@ -18,6 +27,8 @@ homepage = "https://github.com/huggingface/text-generation-inference"
base64 = "0.22.0"
tokenizers = { version = "0.19.1", features = ["http"] }
hf-hub = { version = "0.3.1", features = ["tokio"] }
metrics = { version = "0.23.0" }
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
[profile.release]
incremental = true

View File

@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt

23
Dockerfile.trtllm Normal file
View File

@ -0,0 +1,23 @@
# All the tooling for CUDA
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder
WORKDIR /usr/src/tgi/backends/trtllm
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget
COPY . /usr/src/tgi
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl
# All the tooling for Rust
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src
# Include CUDA related libraries and tools to the Rust based image
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
ENV PATH=/usr/local/cuda/bin:$PATH
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH
RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3

View File

@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt

View File

@ -12,6 +12,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
@ -34,6 +35,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher
RUN cargo build --profile release-opt

View File

@ -5,13 +5,13 @@ install-server-cpu:
cd server && make install-server
install-router:
cd router && cargo install --path .
cargo install --path backends/v3/
install-launcher:
cd launcher && cargo install --path .
cargo install --path launcher/
install-benchmark:
cd benchmark && cargo install --path .
cargo install --path benchmark/
install: install-server install-router install-launcher

View File

@ -0,0 +1,63 @@
cmake_minimum_required(VERSION 3.20)
project(tgi-trtllm-backend VERSION 1.0.0)
set(CMAKE_CXX_STANDARD 20)
include(FetchContent)
include(ExternalProject)
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
#### External dependencies ####
include(cmake/fmt.cmake)
include(cmake/json.cmake)
include(cmake/spdlog.cmake)
include(cmake/trtllm.cmake)
# Let's build TRTLLM as part of CMake
add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
# Tell CMake to need try to override the RPATH for executorWorker as it has not information on how to do so
set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)
# TGI TRTLLM Backend definition
add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h)
include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
target_include_directories(tgi_trtllm_backend_impl PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
)
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
#### Unit Tests ####
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
message(STATUS "Building tests")
FetchContent_Declare(
Catch2
GIT_REPOSITORY https://github.com/catchorg/Catch2
GIT_TAG v3.6.0
)
FetchContent_MakeAvailable(Catch2)
# add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp)
# target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml)
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
include(CTest)
include(Catch)
# catch_discover_tests(tgi_trtllm_backend_tests)
endif ()

View File

@ -0,0 +1,26 @@
[package]
name = "text-generation-backends-trtllm"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[dependencies]
async-trait = "0.1"
async-stream = "0.3"
cxx = "1.0"
text-generation-router = { path = "../../router" }
tokenizers = { version = "0.19", features = ["hf-hub"] }
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.15"
clap = { version = "4.5", features = ["derive"] }
thiserror = "1.0.62"
tracing = "0.1"
tracing-opentelemetry = "0.24"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
log = { version = "0.4", features = [] }
[build-dependencies]
cmake = "0.1"
cxx-build = { version = "1.0", features = ["parallel"] }
pkg-config = "0.3"

100
backends/trtllm/Dockerfile Normal file
View File

@ -0,0 +1,100 @@
ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
ARG OMPI_VERSION="4.1.6"
# Build dependencies resolver stage
FROM lukemathwalker/cargo-chef:latest AS chef
WORKDIR /usr/src/text-generation-inference
FROM chef AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt update && apt install -y \
build-essential \
cmake \
curl \
gcc \
g++ \
git \
git-lfs \
libssl-dev \
ninja-build \
pkg-config \
python3 \
python3-setuptools \
tar \
wget
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
# Install OpenMPI
FROM cuda-builder AS mpi-builder
ARG OMPI_VERSION
ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
mkdir /usr/src/mpi && \
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
cd /usr/src/mpi && \
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \
make -j all && \
make install && \
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
# Install TensorRT
FROM cuda-builder AS trt-builder
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
RUN chmod +x /opt/install_tensorrt.sh && \
/opt/install_tensorrt.sh
# Build Backend
FROM cuda-builder AS tgi-builder
WORKDIR /usr/src/text-generation-inference
# Install Rust
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
chmod -R a+w /root/.rustup && \
chmod -R a+w /root/.cargo
ENV PATH="/root/.cargo/bin:$PATH"
RUN cargo install cargo-chef
# Cache dependencies
COPY --from=planner /usr/src/text-generation-inference/recipe.json .
RUN cargo chef cook --release --recipe-path recipe.json
# Build actual TGI
ARG CUDA_ARCH_LIST
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH"
COPY . .
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
WORKDIR /usr/local/tgi/bin
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
FROM runtime
LABEL co.huggingface.vendor="Hugging Face Inc."
LABEL org.opencontainers.image.authors="hardware@hf.co"
ENTRYPOINT ["./text-generation-launcher"]
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]

46
backends/trtllm/README.md Normal file
View File

@ -0,0 +1,46 @@
# Text Generation Inference - TensorRT-LLM Backend Implementation
## Description
This folder provides the sources of the TensorRT-LLM backend implementation powered by TensorRT-LLM Executor new API
## Simplified Request Sequence
```mermaid
sequenceDiagram
actor User
participant TextGenerationInference.HttpServer
participant TextGenerationInference.TensorRtLlmBackend
participant TextGenerationInference.TensorRtLlmWorkerThread
participant TensorRtLlm.Executor
participant Nvidia.Gpu
User ->> TextGenerationInference.HttpServer: POST /generate
TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters
TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher
activate Nvidia.Gpu
TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the poll for execution
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Response with an unique request identifier
rect rgb(10, 92, 54)
loop every 100us
rect rgb(15, 81, 50)
alt Acquire lock to query executor
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated
else There are new generated tokens
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted)
rect rgb(11, 110, 79)
alt Generated token is final
TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection
else Generated token is not final
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded
end
end
end
end
deactivate Nvidia.Gpu
end
end
```

150
backends/trtllm/build.rs Normal file
View File

@ -0,0 +1,150 @@
use cxx_build::CFG;
use pkg_config;
use std::env;
use std::env::consts::ARCH;
use std::path::{absolute, PathBuf};
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
const CUDA_REQUIRED_VERSION: &str = "12.5";
const MPI_REQUIRED_VERSION: &str = "4.1";
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
// Dependencies
const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
("dylib", "tensorrt_llm"),
("static", "tensorrt_llm_executor_static"),
("dylib", "tensorrt_llm_nvrtc_wrapper"),
("dylib", "nvinfer_plugin_tensorrt_llm"),
("dylib", "decoder_attention"),
];
macro_rules! probe {
($name: expr, $version: expr) => {
if let Err(_) = pkg_config::probe_library($name) {
pkg_config::probe_library(&format!("{}-{}", $name, $version))
.expect(&format!("Failed to locate {}", $name));
}
};
}
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
// Build the backend implementation through CMake
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default
let mut install_path = PathBuf::from(install_path);
if !install_path.is_absolute() {
install_path = absolute(out_dir).expect("cannot happen").join(install_path);
}
let _ = cmake::Config::new(".")
.uses_cxx11()
.generator("Ninja")
.profile(match is_debug {
true => "Debug",
false => "Release",
})
.env("OPT_LEVEL", opt_level)
.define("CMAKE_INSTALL_PREFIX", &install_path)
.define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
.define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
.define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
.build();
// Additional transitive CMake dependencies
let deps_folder = out_dir.join("build").join("_deps");
for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
let dep_name = match is_debug {
true => format!("{}d", dependency),
false => String::from(dependency),
};
let dep_path = deps_folder.join(format!("{}-build", dependency));
println!("cargo:rustc-link-search={}", dep_path.display());
println!("cargo:rustc-link-lib=static={}", dep_name);
}
// Emit linkage information from the artifacts we just built
let install_lib_path = install_path.join("lib");
println!(
r"cargo:warning=Adding link search path: {}",
install_lib_path.display()
);
println!(r"cargo:rustc-link-search={}", install_lib_path.display());
(PathBuf::from(install_path), deps_folder)
}
fn build_ffi_layer(deps_folder: &PathBuf) {
CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs")
.static_flag(true)
.include(deps_folder.join("fmt-src").join("include"))
.include(deps_folder.join("spdlog-src").join("include"))
.include(deps_folder.join("json-src").join("include"))
.include(deps_folder.join("trtllm-src").join("cpp").join("include"))
.include("/usr/local/cuda/include")
.include("/usr/local/tensorrt/include")
.file("src/ffi.cpp")
.std("c++20")
.compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt");
println!("cargo:rerun-if-changed=include/backend.h");
println!("cargo:rerun-if-changed=lib/backend.cpp");
println!("cargo:rerun-if-changed=include/ffi.h");
println!("cargo:rerun-if-changed=src/ffi.cpp");
}
fn main() {
// Misc variables
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let build_profile = env::var("PROFILE").unwrap();
let (is_debug, opt_level) = match build_profile.as_ref() {
"debug" => (true, "0"),
_ => (false, "3"),
};
// Build the backend
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
// Build the FFI layer calling the backend above
build_ffi_layer(&deps_folder);
// Emit linkage search path
probe!("ompi", MPI_REQUIRED_VERSION);
// Probe CUDA & co. with pkg-config
CUDA_TRANSITIVE_DEPS.iter().for_each(|name| {
probe!(name, CUDA_REQUIRED_VERSION);
});
// NCCL is slightly trickier because it might not have a pkgconfig installed
let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH);
let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default);
println!(r"cargo:rustc-link-search=native={}", nccl_library_path);
println!("cargo:rustc-link-lib=dylib=nccl");
// TensorRT
let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib");
println!(r"cargo:rustc-link-search=native={}", tensort_library_path);
println!("cargo:rustc-link-lib=dylib=nvinfer");
// TensorRT-LLM
TENSORRT_LLM_TRANSITIVE_DEPS
.iter()
.for_each(|(link_type, name)| {
println!("cargo:rustc-link-lib={}={}", link_type, name);
});
// Backend
BACKEND_DEPS.iter().for_each(|name| {
println!("cargo:rustc-link-lib=static={}", name);
});
}

View File

@ -0,0 +1,6 @@
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt
GIT_TAG 11.0.1
)
FetchContent_MakeAvailable(fmt)

View File

@ -0,0 +1,5 @@
fetchcontent_declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
)
fetchcontent_makeavailable(json)

View File

@ -0,0 +1,17 @@
set(SPDLOG_USE_FMT ON)
set(SPDLOG_BUILD_SHARED OFF)
set(SPDLOG_FMT_EXTERNAL ON)
# Define the level at which SPDLOG_ compilation level is defined
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
else ()
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
endif ()
fetchcontent_declare(
spdlog
GIT_REPOSITORY https://github.com/gabime/spdlog.git
GIT_TAG v1.14.1
)
fetchcontent_makeavailable(spdlog)

View File

@ -0,0 +1,42 @@
set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})
set(USE_CXX11_ABI ON)
set(BUILD_PYT OFF)
set(BUILD_PYBIND OFF)
set(BUILD_MICRO_BENCHMARKS OFF)
set(BUILD_BENCHMARKS OFF)
set(BUILD_TESTS OFF)
set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})
message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
set(FAST_BUILD ON)
set(NVTX_DISABLE OFF)
else ()
set(FAST_BUILD OFF)
set(FAST_MATH ON)
set(NVTX_DISABLE ON)
endif ()
fetchcontent_declare(
trtllm
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1
GIT_SHALLOW FALSE
)
fetchcontent_makeavailable(trtllm)
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here
set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name")
set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}"
CACHE INTERNAL "nvrtc wrapper library path")
# The same Executor Static library
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name")
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path")

View File

@ -0,0 +1,121 @@
//
// Created by Morgan Funtowicz on 6/30/24.
//
#ifndef TGI_TRTLLM_BACKEND_H
#define TGI_TRTLLM_BACKEND_H
#include <cmath>
#include <filesystem>
#include <span>
#include <vector>
#include <nlohmann/json.hpp>
#include <tensorrt_llm/runtime/common.h>
#include <tensorrt_llm/executor/executor.h>
#include <tensorrt_llm/plugins/api/tllmPlugin.h>
using json = nlohmann::json;
namespace tle = tensorrt_llm::executor;
namespace huggingface::tgi::backends {
using RequestId = tle::IdType;
using TokenId = tle::TokenIdType;
/**
* Initialize all the components required by TRTLLM.
* It is required to call this function before attempting to load any engine
*/
void InitializeBackend();
/**
*
* @param config TensorRT-LLM configuration object
* @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
* @return
*/
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
/**
* Get the sampling configuration from the parameters provided by TGI
* @param topK
* @param topP
* @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param seed
* @return
*/
tle::SamplingConfig GetSamplingConfig(
uint32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed
);
/**
*
*/
class TensorRtLlmBackend {
private:
const json config;
tle::Executor executor;
public:
explicit TensorRtLlmBackend(
const std::filesystem::path &engineFolder,
const std::filesystem::path &executorWorker
);
/**
* Indicate if the backend is ready to accept incoming request
* @return true if ready, false otherwise
*/
[[nodiscard]] bool IsReady() const;
/**
* Query the executor for the number of token available for pulling
* @return
*/
[[nodiscard]] size_t NumResponsesReady() const;
/**
* Submit a new generation task to the executor
* @param tokens
* @param topK
* @param topP
* @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param seed
* @return Request id related to this generation for reference
*/
[[nodiscard]] RequestId Submit(
const std::vector<TokenId> &tokens,
int32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed
);
/**
*
* @param requestId The request id to poll the generation results
* @return
*/
std::vector<tle::Response> Poll(RequestId requestId);
/**
* Stop the underlying executor
*/
void Shutdown();
};
}
#endif //TGI_TRTLLM_BACKEND_H

View File

@ -0,0 +1,75 @@
//
// Created by mfuntowicz on 7/11/24.
//
#ifndef TGI_TRTLLM_BACKEND_FFI_H
#define TGI_TRTLLM_BACKEND_FFI_H
#include <cstddef>
#include "backend.h"
namespace huggingface::tgi::backends {
class TensorRtLlmBackendImpl;
}
#include "backends/trtllm/src/lib.rs.h"
namespace huggingface::tgi::backends {
// struct GenerationContext;
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
public:
/***
*
* @param engineFolder
* @param executorWorker
*/
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
/***
*
* @return
*/
bool IsReady() const;
/***
*
* @param tokens
* @param topK
* @param topP
* @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param seed
* @return
*/
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
uint64_t
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
/***
*
* @param requestId
* @param ctx
* @param callback
* @return
*/
size_t StreamTokens(
const RequestId requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback);
};
/***
*
* @param engineFolder
* @return
*/
std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
}
#endif //TGI_TRTLLM_BACKEND_FFI_H

View File

@ -0,0 +1,59 @@
//
// Created by mfuntowicz on 7/23/24.
//
#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H
#define TGI_TRTLLM_BACKEND_HARDWARE_H
#include <cstdint>
#include <limits>
#include <fmt/base.h>
#include <spdlog/spdlog.h>
#include <nvml.h>
namespace huggingface::hardware::cuda {
#define AMPERE_SM_MAJOR 8
#define HOPPER_SM_MAJOR 8
/**
* Store information about the version of the CUDA Compute Capabilities detected on the device
*/
struct CudaComputeCapabilities {
int32_t major;
int32_t minor;
[[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
[[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; }
};
CudaComputeCapabilities GetCudaComputeCapabilities() {
// Get the compute capabilities of the current hardware
nvmlDevice_t device;
CudaComputeCapabilities capabilities{0, 0};
if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) {
SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor);
}
}
return capabilities;
}
/**
* Return the number of GPU detected. If no GPU is detected, return size_t::max()
* @return
*/
std::optional<size_t> GetNumDevices() {
uint32_t numGpus = 0;
if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
return std::optional(numGpus);
} else {
return std::nullopt;
}
}
}
#endif //TGI_TRTLLM_BACKEND_HARDWARE_H

View File

@ -0,0 +1,146 @@
#include <fstream>
#include <fmt/ranges.h>
#include <spdlog/spdlog.h>
#include <nvml.h>
#include "backend.h"
#include "hardware.h"
void huggingface::tgi::backends::InitializeBackend() {
SPDLOG_INFO("Initializing Backend...");
nvmlInit_v2();
initTrtLlmPlugins();
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
if (numGpus.has_value()) {
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
} else {
SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system");
}
}
[[nodiscard]]
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
tle::ExecutorConfig execConfig(1);
// Retrieve the compute capabilities to enable some options at runtime
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
if (config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1) {
SPDLOG_INFO("Detected single engine deployment, using leader mode");
execConfig.setParallelConfig(tle::ParallelConfig(
tle::CommunicationType::kMPI,
tle::CommunicationMode::kLEADER,
std::nullopt,
std::nullopt,
std::nullopt
));
} else { // Multiple engines -> using orchestrator mode (MPI involved)
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
execConfig.setParallelConfig(tle::ParallelConfig(
tle::CommunicationType::kMPI,
tle::CommunicationMode::kORCHESTRATOR,
std::nullopt,
std::nullopt,
tle::OrchestratorConfig(true, workerPath, nullptr, true)
));
}
// Define some configuration variables
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere());
return execConfig;
}
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
uint32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed) {
return tle::SamplingConfig(
1, // TGI only use a single beam
topK,
topP,
std::nullopt,
std::nullopt,
std::nullopt,
seed,
temperature,
temperature,
std::nullopt,
repetition_penalty,
std::nullopt,
frequency_penalty
);
}
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
const std::filesystem::path &enginesFolder,
const std::filesystem::path &executorWorker
) :
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
executor(
enginesFolder,
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
GetExecutorConfig(config, executorWorker.string()
)) {
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
}
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
return executor.canEnqueueRequests();
}
[[nodiscard("Returned number of requests needs to be consumed")]]
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
return executor.getNumResponsesReady();
}
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const std::vector<tle::TokenIdType> &tokens,
const int32_t topK,
const float_t topP,
const float_t temperature,
const float_t repetition_penalty,
const float_t frequency_penalty,
const uint64_t seed
) {
#ifdef NDEBUG
SPDLOG_DEBUG(
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
tokens.size(),
executor.getLatestIterationStats().back().numActiveRequests
);
#else
SPDLOG_DEBUG(
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
fmt::join(tokens, ", "),
executor.getLatestIterationStats().front().numActiveRequests
);
#endif
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
const auto output = tle::OutputConfig(true, false, false, true, false);
return executor.enqueueRequest(
tle::Request{tokens, maxNewTokens, true, sampling, output});
}
[[nodiscard("Generated tokens result must be used")]]
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
return executor.awaitResponses(requestId);
}
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
SPDLOG_INFO("Shutting down executor");
executor.shutdown();
}

View File

@ -0,0 +1,111 @@
#!/bin/bash
set -ex
TRT_VER="10.2.0.19"
CUDA_VER="12.5"
CUDNN_VER="9.2.1.18-1"
NCCL_VER="2.22.3-1+cuda12.5"
CUBLAS_VER="12.5.3.2-1"
NVRTC_VER="12.5.82-1"
for i in "$@"; do
case $i in
--TRT_VER=?*) TRT_VER="${i#*=}";;
--CUDA_VER=?*) CUDA_VER="${i#*=}";;
--CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
--NCCL_VER=?*) NCCL_VER="${i#*=}";;
--CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
*) ;;
esac
shift
done
NVCC_VERSION_OUTPUT=$(nvcc --version)
if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
exit 1
fi
install_ubuntu_requirements() {
apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates
ARCH=$(uname -m)
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
dpkg -i cuda-keyring_1.0-1_all.deb
apt-get update
if [[ $(apt list --installed | grep libcudnn9) ]]; then
apt-get remove --purge -y --allow-change-held-packages libcudnn9*
fi
if [[ $(apt list --installed | grep libnccl) ]]; then
apt-get remove --purge -y --allow-change-held-packages libnccl*
fi
if [[ $(apt list --installed | grep libcublas) ]]; then
apt-get remove --purge -y --allow-change-held-packages libcublas*
fi
if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then
apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev*
fi
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER}
apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER}
apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}
# NVRTC static library doesn't exist in NGC PyTorch container.
NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER}
apt-get clean
rm -rf /var/lib/apt/lists/*
}
install_centos_requirements() {
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
yum -y update
yum -y install epel-release
yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER}
yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}
yum clean all
}
install_tensorrt() {
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
TRT_CUDA_VERSION="12.5"
if [ -z "$RELEASE_URL_TRT" ];then
ARCH=${TRT_TARGETARCH}
if [ -z "$ARCH" ];then ARCH=$(uname -m);fi
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
fi
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
tar -xf /tmp/TensorRT.tar -C /usr/local/
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
rm -rf /tmp/TensorRT.tar
}
# Install base packages depending on the base OS
ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
case "$ID" in
debian)
install_ubuntu_requirements
install_tensorrt
;;
ubuntu)
install_ubuntu_requirements
install_tensorrt
;;
centos)
install_centos_requirements
install_tensorrt
;;
*)
echo "Unable to determine OS..."
exit 1
;;
esac

View File

@ -0,0 +1,329 @@
use std::future::Future;
use std::path::Path;
use std::pin::{pin, Pin};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use async_trait::async_trait;
use cxx::UniquePtr;
use log::{error, warn};
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::RwLock;
use tokio::time::{sleep, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::{Stream, StreamExt};
use tracing::{instrument, span, Level};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::UnsupportedModality;
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
// Value used to poll the state of the generation stream
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
type InferResult<T> = Result<T, InferError>;
pub(crate) struct Generation {
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
done: Arc<AtomicBool>,
}
/// Holds the user provided input to be executed along with a channel allowing
/// to bubble up all the generated tokens for that tokens the to end stream.
pub struct GenerationContext {
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokenizer: Arc<Tokenizer>,
tokens: Vec<u32>,
done: Arc<AtomicBool>,
queued: Instant,
start: Option<Instant>,
}
impl Stream for Generation {
type Item = usize;
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let interval = POLLING_INTERVAL_US.get_or_init(|| {
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
});
if !self.done.load(Ordering::Relaxed) {
let backend = pin!(self.executor.read());
let status = match backend.poll(ctx) {
Poll::Ready(executor_r) => {
let ready = executor_r.num_responses_ready();
if ready == 0 {
Poll::Pending
} else {
Poll::Ready(Some(ready))
}
}
Poll::Pending => Poll::Pending,
};
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_micros(*interval)).await;
waker.wake();
});
status
} else {
Poll::Ready(None) // end of stream
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(1, None)
}
}
unsafe impl Send for TensorRtLlmBackendImpl {}
unsafe impl Sync for TensorRtLlmBackendImpl {}
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
pub struct TensorRtLlmBackend {
tokenizer: Arc<Tokenizer>,
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
// the number of available tokens (read only) in the Generation stream
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
}
impl TensorRtLlmBackend {
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
) -> Result<Self, TensorRtLlmBackendError> {
Ok(TensorRtLlmBackend {
tokenizer: Arc::new(tokenizer),
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
engine_folder.as_ref().to_str().unwrap(),
executor_worker_path.as_ref().to_str().unwrap(),
))),
})
}
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
if request.top_n_tokens > 1 {
return Err(InferError::ValidationError(
ValidationError::TopNTokensDisabled,
));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(InferError::ValidationError(ValidationError::Grammar));
}
match request.inputs.len() {
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
2.. => Err(InferError::GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(text) => Ok(text),
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
},
}
}
fn generate(
&self,
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokens: Vec<u32>,
top_k: u32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) {
let tokenizer = Arc::clone(&self.tokenizer);
let executor = Arc::clone(&self.backend);
// Let's push this in async context
tokio::spawn(async move {
// Define the generation state
let mut generation = Generation {
executor: executor.clone(),
done: Arc::new(AtomicBool::new(false)),
};
// Define the context over the generation
// TODO(asap): Do we really need so many shared-ownership?
let ctx = Box::new(GenerationContext {
sender: sender.clone(),
tokenizer,
tokens: vec![],
done: Arc::clone(&generation.done),
start: None,
queued: Instant::now(),
});
// We are leaking the context on-purpose to avoid the box being dropped while there are
// still computation ongoing
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
let ctx_ = Box::leak(ctx);
// Submit the request to the batcher
let request_id = span!(Level::DEBUG, "submit")
.in_scope(|| async {
let mut handle = executor.write().await;
let request_id = handle.pin_mut().submit(
&tokens,
top_k as i32,
top_p,
temperature,
repetition_penalty,
frequency_penalty,
seed,
);
request_id
})
.await;
while let Some(_) = generation.next().await {
let mut executor_w = executor.write().await;
let executor = executor_w.pin_mut();
span!(Level::DEBUG, "decode")
.in_scope(|| async {
unsafe {
executor.stream_tokens(
request_id,
ctx_,
|ctx: *mut GenerationContext, step: GenerationStep| {
let inner_ctx = &mut *ctx;
// Update the timestamp at which the request started effectively
// Can be a bit off, would need to be before the callback, let's see
inner_ctx.start.get_or_insert(Instant::now());
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
// Ensure we are not running into errors
let parcel = if !step.has_error {
// Insert the latest generated token to the tracker
inner_ctx.tokens.push(step.token_id);
// Decode the token
let text = inner_ctx
.tokenizer
.decode(&[step.token_id], true)
.expect("Failed to decode token");
let special = inner_ctx
.tokenizer
.get_added_vocabulary()
.is_special_token(&text);
// Create the structure holding the token
let token = Token {
id: step.token_id,
text,
logprob: step.log_prob,
special,
};
if step.is_final {
let generated_text = inner_ctx
.tokenizer
.decode(&inner_ctx.tokens, true)
.expect("Failed to decode generated_tokens");
Ok(InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text: GeneratedText {
text: generated_text,
generated_tokens: inner_ctx.tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
},
start: inner_ctx.start.unwrap_or(Instant::now()),
queued: inner_ctx.queued,
})
} else {
Ok(InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
})
}
} else {
error!("Error caught while decoding: {}", &step.error_msg);
Err(InferError::GenerationError(step.error_msg))
};
// Send the parcel to the client
inner_ctx
.sender
.send(parcel)
.expect("Failed to sent msg through the channel");
},
);
}
})
.await;
}
// "Properly" free the shared context...
// TODO: clean that piece of sh** asap
unsafe {
let _ = Box::from_raw(ctx_);
}
});
}
}
#[async_trait]
impl Backend for TensorRtLlmBackend {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
// Let's add a few more validation
let input = TensorRtLlmBackend::validate(&request)?;
// Channel to stream the generated token as they come from the worker thread back to the transport layer
let (sender, receiver) = unbounded_channel();
// Unpack parameters
let params = &request.parameters;
// Preprocess the inputs to send to TRTLLM backend
let encoding = self
.tokenizer
.encode(input.as_str(), true)
.map_err(|e| InferError::GenerationError(e.to_string()))?;
// Generate the response
self.generate(
sender,
Vec::from(encoding.get_ids()),
params.top_k,
params.top_p,
params.temperature,
params.repetition_penalty,
params.frequency_penalty,
params.seed,
);
Ok(UnboundedReceiverStream::new(receiver))
}
async fn health(&self, _current_health: bool) -> bool {
true
}
}

View File

@ -0,0 +1,15 @@
use thiserror::Error;
use text_generation_router::server;
#[derive(Debug, Error)]
pub enum TensorRtLlmBackendError {
#[error("Tokenizer error: {0}")]
Tokenizer(String),
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}

View File

@ -0,0 +1,84 @@
//
// Created by mfuntowicz on 6/30/24.
//
#pragma once
#include <cmath>
#include <exception>
#include <filesystem>
#include <limits>
#include <iterator>
#include <vector>
#include <spdlog/spdlog.h>
#include "backends/trtllm/include/ffi.h"
huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
const std::string_view &engineFolder,
const std::string_view &executorWorker
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
return TensorRtLlmBackend::IsReady();
}
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
float_t frequency_penalty, uint64_t seed) {
// This will copy all the items from the initial slice
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
return TensorRtLlmBackend::Submit(
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
}
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
const uint64_t requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback) {
size_t numTokens = 0;
for (const auto &item: Poll(requestId)) {
GenerationStep step;
if (!item.hasError()) {
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
const auto decoded = item.getResult();
const auto token = decoded.outputTokenIds[0][0];
const auto isFinal = decoded.isFinal;
const auto logProb = decoded.logProbs.value()[0][0];
++numTokens;
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
step = huggingface::tgi::backends::GenerationStep{
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
};
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
} else {
// TODO : Return rest::Result with error
const auto what = item.getErrorMsg();
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
step = huggingface::tgi::backends::GenerationStep{
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
};
}
callback(std::move(ctx), std::move(step));
}
return numTokens;
}
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
// Unconditionally call this to initialize and discover TRTLLM plugins
InitializeBackend();
const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
}

View File

@ -0,0 +1,78 @@
pub use backend::{GenerationContext, TensorRtLlmBackend};
mod backend;
pub mod errors;
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
/// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration
pub struct GenerationStep {
token_id: u32,
log_prob: f32,
is_final: bool,
has_error: bool,
error_msg: String,
}
extern "Rust" {
type GenerationContext;
}
unsafe extern "C++" {
include!("backends/trtllm/src/ffi.cpp");
/// Represent an instance of the underlying TensorRT-LLM backend
type TensorRtLlmBackendImpl;
/// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
///
/// # Arguments
///
/// * `engine_folder`: Path to the folder containing all the TRTLLM engines
/// * `executor_worker`: Path to the TRTLLM executor worker
///
/// returns: <unknown>
///
/// # Examples
///
/// ```
///
/// ```
#[rust_name = "create_tensorrt_llm_backend"]
fn CreateTensorRtLlmBackend(
engine_folder: &str,
executor_worker: &str,
) -> UniquePtr<TensorRtLlmBackendImpl>;
// #[rust_name = "is_ready"]
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
#[rust_name = "num_responses_ready"]
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
#[rust_name = "submit"]
fn Submit(
self: Pin<&mut TensorRtLlmBackendImpl>,
tokens: &[u32],
top_k: i32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) -> u64;
#[rust_name = "stream_tokens"]
unsafe fn StreamTokens(
self: Pin<&mut TensorRtLlmBackendImpl>,
request_id: u64,
ctx: *mut GenerationContext,
cb: unsafe fn(*mut GenerationContext, GenerationStep),
) -> usize;
// #[rust_name = "shutdown"]
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
}
}

166
backends/trtllm/src/main.rs Normal file
View File

@ -0,0 +1,166 @@
use std::collections::HashMap;
use std::path::PathBuf;
use clap::Parser;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackend;
use text_generation_router::server;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(long, env, required = true)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env)]
model_id: String,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env)]
auth_token: Option<String>,
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
executor_worker: PathBuf,
}
#[tokio::main]
async fn main() -> Result<(), TensorRtLlmBackendError> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
max_batch_prefill_tokens,
max_batch_total_tokens,
hostname,
port,
tokenizer_name,
tokenizer_config_path,
revision,
model_id,
validation_workers,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
messages_api_enabled,
max_client_batch_size,
auth_token,
executor_worker,
} = args;
// Launch Tokio runtime
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(TensorRtLlmBackendError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
if !executor_worker.exists() {
return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
"`executor_work` specified path doesn't exists: {}",
executor_worker.display()
)));
}
// Run server
let tokenizer = Tokenizer::from_pretrained(
tokenizer_name.clone(),
Some(FromPretrainedParameters {
revision: revision.clone().unwrap_or(String::from("main")),
user_agent: HashMap::new(),
auth_token,
}),
)
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
server::run(
backend,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
validation_workers,
None,
tokenizer_name,
tokenizer_config_path,
revision,
hostname,
port,
cors_allow_origin,
false,
None,
None,
messages_api_enabled,
true,
max_client_batch_size,
)
.await?;
Ok(())
}

View File

@ -0,0 +1,14 @@
//
// Created by mfuntowicz on 7/2/24.
//
#include <catch2/catch_all.hpp>
#include <spdlog/spdlog.h>
#include "../include/backend.h"
TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") {
const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/");
const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker");
spdlog::info("Loading config from: {}", absolute(engines).string());
huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor);
}

66
backends/v3/Cargo.toml Normal file
View File

@ -0,0 +1,66 @@
[package]
name = "text-generation-router-v3"
description = "Text Generation Webserver"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-router"
path = "src/main.rs"
[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
text-generation-router = { path = "../../router" }
clap = { version = "4.4.5", features = ["derive", "env"] }
grpc-metadata = { path = "../grpc-metadata" }
futures = "0.3.28"
hf-hub = { workspace = true }
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
rand = "0.8.5"
reqwest = { version = "0.11.20", features = [] }
serde = "1.0.188"
serde_json = "1.0.107"
thiserror = "1.0.48"
tokenizers = { workspace = true}
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.14"
tower-http = { version = "0.5.1", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa = { version = "4.2.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = { version = "2.0.2" }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
futures-util = "0.3.30"
regex = "1.10.3"
once_cell = "1.19.0"
image = "0.25.1"
base64 = { workspace = true }
prost = "^0.12"
tonic = "^0.10"
tower = "^0.4"
[build-dependencies]
tonic-build = "0.10.1"
prost-build = "0.12.1"
[features]
default = ["ngrok"]
ngrok = ["text-generation-router/ngrok"]
google = ["text-generation-router/google"]
kserve = ["text-generation-router/kserve"]

19
backends/v3/build.rs Normal file
View File

@ -0,0 +1,19 @@
use std::fs;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("cargo:rerun-if-changed=../../proto/");
fs::create_dir_all("src/client/pb").unwrap_or(());
let mut config = prost_build::Config::new();
config.protoc_arg("--experimental_allow_proto3_optional");
tonic_build::configure()
.build_client(true)
.build_server(false)
.out_dir("src/client/pb")
.include_file("mod.rs")
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
Ok(())
}

501
backends/v3/src/backend.rs Normal file
View File

@ -0,0 +1,501 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
pub struct BackendV3 {
/// Request queue
queue: Queue,
/// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>,
/// Client clone, used for health checks to skip the queue
client: ShardedClient,
}
impl BackendV3 {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
) -> Self {
let queue = Queue::new(
requires_padding,
16,
window_size,
speculate,
max_batch_total_tokens,
);
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client.clone(),
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
queue.clone(),
batching_task_notifier.clone(),
));
Self {
queue,
batching_task_notifier,
client,
}
}
}
#[async_trait]
impl Backend for BackendV3 {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();
// Append the request to the queue
self.queue.append(Entry {
request,
response_tx,
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
});
// Notify the background task that we have a new entry in the queue that needs
// to be batched
self.batching_task_notifier.notify_one();
// Return stream
Ok(UnboundedReceiverStream::new(response_rx))
}
async fn health(&self, current_health: bool) -> bool {
if current_health {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
self.client.model_health().await
}
.is_ok()
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
queue: Queue,
notifier: Arc<Notify>,
) {
// Infinite loop
loop {
// Wait for a notification from the Infer struct
notifier.notified().await;
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) = queue
.next_batch(
None,
max_batch_size,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
} else {
// Minimum batch size
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
}
// Create span for this batch to add context to inference calls
let next_batch_size = entries.len();
let next_batch_span =
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
#[instrument(skip_all)]
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
}
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
}
/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let mut batch = next_batch?;
// No need to filter
if batch.size as usize == entries.len() {
return Some(batch);
}
let id = batch.id;
// Retain only requests that are still in entries
batch.request_ids.retain(|id| entries.contains_key(id));
if batch.request_ids.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.clear_cache(Some(id)).await.unwrap();
None
} else {
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, batch.request_ids).await.unwrap()
}
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation responses back to the infer task
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
});
}
/// Send responses through the `entry` response channel
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let prefill_tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs)
.zip(prefill_tokens.texts)
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
}
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
let token = Token {
id,
text,
logprob,
special,
};
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids
.iter()
.zip(top_tokens_.logprobs.iter())
.zip(top_tokens_.texts.iter())
.zip(top_tokens_.is_special.iter())
.map(|(((&id, &logprob), text), &special)| Token {
id,
text: text.to_string(),
logprob,
special,
})
.collect()
} else {
vec![]
};
match (&generation.generated_text, iterator.peek()) {
(Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: GeneratedText::from(generated_text.clone()),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
}
}
Ok(stopped)
}
/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
});
}
impl From<crate::client::GeneratedText> for GeneratedText {
fn from(value: crate::client::GeneratedText) -> Self {
let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v3_finish_reason {
crate::client::FinishReason::Length => FinishReason::Length,
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
};
Self {
text: value.text,
generated_tokens: value.generated_tokens,
finish_reason,
seed: value.seed,
}
}
}

View File

@ -0,0 +1,284 @@
/// Single shard Client
use crate::client::{pb, Chunk};
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
/// Text Generation Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
stub: TextGenerationServiceClient<Channel>,
}
impl Client {
/// Returns a client connected to the given url
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> {
let channel = Channel::builder(uri).connect().await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let channel = Channel::from_shared("http://[::]:50051".to_string())
.unwrap()
.connect_with_connector(tower::service_fn(move |_: Uri| {
tokio::net::UnixStream::connect(path.clone())
}))
.await?;
Ok(Self {
stub: TextGenerationServiceClient::new(channel),
})
}
/// Returns a list of uris or unix sockets of all shards
#[instrument(skip(self))]
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
let response = self.stub.service_discovery(request).await.map_err(|_| {
ClientError::Connection("Server does not support v3 interface".to_string())
})?;
let urls = response
.into_inner()
.urls
.into_iter()
// Remove unix socket prefix
.map(|url| match url.strip_prefix("unix://") {
None => url,
Some(stripped_url) => stripped_url.to_string(),
})
.collect();
Ok(urls)
}
/// Get model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<InfoResponse> {
let request = tonic::Request::new(InfoRequest {}).inject_context();
let response = self.stub.info(request).await?.into_inner();
Ok(response)
}
/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let request = tonic::Request::new(HealthRequest {}).inject_context();
let response = self.stub.health(request).await?.into_inner();
Ok(response)
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
self.stub.clear_cache(request).await?;
Ok(())
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let request = tonic::Request::new(FilterBatchRequest {
batch_id,
request_ids,
})
.inject_context();
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
Ok(filtered_batch.batch)
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
// Safe unwrap, because we control the data.
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
mimetype: "image/jpeg;base64".to_string(),
})
.into(),
);
}
// Send stringly-typed inputs for compatibility for backends that haven't
// been updated to support chunks.
let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
inputs.push_str(&format!(
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
));
}
requests.push(Request {
id: 0,
inputs,
input_chunks: Some(Input {
chunks: input_chunks,
}),
// We truncate the input on the server side to be sure that it has the correct size
truncate,
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
top_k: 10,
top_p: 0.9,
typical_p: 0.9,
do_sample: false,
seed: 0,
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;
// Check max_batch_size
if Some(requests.len()) == max_batch_size {
break;
}
}
let batch = Batch {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: max_input_length,
max_blocks: 0,
};
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((
response.generations,
response.batch,
DecodeTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
}
pub struct PrefillTimings {
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}
pub struct DecodeTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl DecodeTimings {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}

View File

@ -0,0 +1,76 @@
//! Text Generation gRPC client library
use async_trait::async_trait;
use thiserror::Error;
use tonic::transport;
use tonic::Status;
#[allow(clippy::derive_partial_eq_without_eq)]
mod pb;
mod grpc_client;
mod sharded_client;
pub use grpc_client::Client;
pub use pb::generate::v3::{
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
#[async_trait]
pub trait Health {
/// Check if a generate server is healthy by asking it to allocate a tensor on device
async fn device_health(&self) -> Result<()>;
/// Check if a generate server is healthy by doing a forward pass.
/// EXPENSIVE
async fn model_health(&self) -> Result<()>;
}
#[derive(Debug)]
pub struct ShardInfo {
pub requires_padding: bool,
pub dtype: String,
pub device_type: String,
pub window_size: Option<u32>,
pub speculate: u32,
}
#[derive(Error, Debug, Clone)]
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")]
Connection(String),
#[error("Server error: {0}")]
Generation(String),
#[error("Sharded results are empty")]
EmptyResults,
}
impl From<Status> for ClientError {
fn from(err: Status) -> Self {
let err = Self::Generation(err.message().to_string());
tracing::error!("{err}");
err
}
}
impl From<transport::Error> for ClientError {
fn from(err: transport::Error) -> Self {
let err = Self::Connection(err.to_string());
tracing::error!("{err}");
err
}
}
// Small convenience re-wrapping of `Chunk`.
impl From<Chunk> for InputChunk {
fn from(chunk: Chunk) -> Self {
InputChunk { chunk: Some(chunk) }
}
}
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
pub type Result<T> = std::result::Result<T, ClientError>;

View File

@ -0,0 +1,260 @@
use crate::client::{ClientError, Result};
/// Multi shard Client
use crate::client::{Health, ShardInfo};
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::{
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use crate::client::{Chunk, InfoResponse, Input};
use async_trait::async_trait;
use futures::future::join_all;
use tonic::transport::Uri;
use tracing::instrument;
#[derive(Debug, Clone)]
/// Text Generation Inference gRPC multi client
pub struct ShardedClient {
clients: Vec<Client>,
}
impl ShardedClient {
fn new(clients: Vec<Client>) -> Self {
Self { clients }
}
/// Create a new ShardedClient from a master client. The master client will communicate with
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
async fn from_master_client(mut master_client: Client) -> Result<Self> {
// Get all uris/unix sockets from the master client
let uris = master_client.service_discovery().await?;
let futures = uris.into_iter().map(Client::connect_uds);
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
Ok(Self::new(clients?))
}
/// Returns a client connected to the given uri
#[allow(dead_code)]
pub async fn connect(uri: Uri) -> Result<Self> {
let master_client = Client::connect(uri).await?;
Self::from_master_client(master_client).await
}
/// Returns a client connected to the given unix socket
pub async fn connect_uds(path: String) -> Result<Self> {
let master_client = Client::connect_uds(path).await?;
Self::from_master_client(master_client).await
}
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
}
/// GRPC health check
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.health())
.collect();
join_all(futures).await.pop().unwrap()
}
/// Clear the past generations cache
#[instrument(skip(self))]
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.clear_cache(batch_id))
.collect();
join_all(futures).await.into_iter().collect()
}
/// Filter a cached batch
#[instrument(skip(self))]
pub async fn filter_batch(
&mut self,
batch_id: u64,
request_ids: Vec<u64>,
) -> Result<Option<CachedBatch>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
.collect();
// all shards return the same message
join_all(futures).await.pop().unwrap()
}
/// Warmup on a max size batch
///
/// Returns the maximum amount of tokens supported by the hardware
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| {
Box::pin(client.warmup(
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_size,
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<Option<u32>>>>()?;
Ok(results.into_iter().flatten().min())
}
/// Generate one token for each request in the given batch
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
impl From<InfoResponse> for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> {
self.clone().health().await?;
Ok(())
}
async fn model_health(&self) -> Result<()> {
// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: u64::MAX,
inputs: "liveness".to_string(),
input_chunks: Some(Input {
chunks: vec![Chunk::Text("liveness".into()).into()],
}),
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
top_n_tokens: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
adapter_id: None,
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch).await?;
Ok(())
}
}

142
backends/v3/src/lib.rs Normal file
View File

@ -0,0 +1,142 @@
mod backend;
mod block_allocator;
mod client;
mod queue;
use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
/// Mandatory
#[schema(example = "cuda")]
pub model_device_type: String,
#[schema(example = "torch.float16")]
pub model_dtype: String,
/// Backend parameters
#[schema(example = "1")]
pub speculate: usize,
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
#[schema(example = "20")]
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
}
#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
max_input_tokens: usize,
max_total_tokens: usize,
master_shard_uds_path: String,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
) -> Result<(BackendV3, BackendInfo), V3Error> {
// Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(V3Error::NotEnoughMemory(max_total_tokens));
}
Ok(max_supported_batch_total_tokens)
}
}
};
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(V3Error::Connection)?;
// server is running on v3
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(V3Error::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(V3Error::Warmup)?,
)?;
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
let backend_info = BackendInfo {
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
model_device_type: shard_info.device_type.clone(),
model_dtype: shard_info.dtype.clone(),
speculate: shard_info.speculate as usize,
};
let backend = BackendV3::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
);
tracing::info!("Using backend V3");
Ok((backend, backend_info))
}
#[derive(Debug, Error)]
pub enum V3Error {
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
}

208
backends/v3/src/main.rs Normal file
View File

@ -0,0 +1,208 @@
use clap::{Parser, Subcommand};
use text_generation_router::server;
use text_generation_router_v3::{connect_backend, V3Error};
use thiserror::Error;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env, default_value_t)]
disable_usage_stats: bool,
#[clap(long, env, default_value_t)]
disable_crash_reports: bool,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
command,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
hostname,
port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
revision,
validation_workers,
api_key,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
disable_usage_stats,
disable_crash_reports,
max_client_batch_size,
} = args;
if let Some(Commands::PrintSchema) = command {
use utoipa::OpenApi;
let api_doc = text_generation_router::server::ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
};
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
let (backend, _backend_info) = connect_backend(
max_input_tokens,
max_total_tokens,
master_shard_uds_path,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
)
.await?;
// Run server
server::run(
backend,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
validation_workers,
api_key,
tokenizer_name,
tokenizer_config_path,
revision,
hostname,
port,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
)
.await?;
Ok(())
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("Backend failed: {0}")]
Backend(#[from] V3Error),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}

View File

@ -1,17 +1,17 @@
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
use crate::infer::InferError;
use crate::infer::InferStreamResponse;
use crate::validation::{
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
use crate::block_allocator::{BlockAllocation, BlockAllocator};
use crate::client;
use crate::client::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::collections::VecDeque;
use text_generation_client::v3::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
ValidStoppingParameters,
};
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};
@ -337,8 +337,22 @@ impl State {
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
input_chunks: Some(Input {
chunks: entry.request.inputs.clone(),
input_chunks: Some(client::Input {
chunks: entry
.request
.inputs
.clone()
.into_iter()
.map(|c| client::InputChunk {
chunk: Some(match c {
Chunk::Text(text) => client::Chunk::Text(text),
Chunk::Image(image) => client::Chunk::Image(client::Image {
data: image.data,
mimetype: image.mimetype,
}),
}),
})
.collect(),
}),
inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,

View File

@ -21,7 +21,7 @@ float-ord = "0.3.2"
serde = {version = "1.0.188", features = ["derive"]}
serde_json = "1.0"
tabled = "0.14.0"
text-generation-client = { path = "../router/client" }
text-generation-client = { path = "../backends/client" }
thiserror = "1.0.48"
tokenizers = { workspace = true }
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }

View File

@ -1580,16 +1580,11 @@
"type": "object",
"required": [
"model_id",
"model_dtype",
"model_device_type",
"max_concurrent_requests",
"max_best_of",
"max_stop_sequences",
"max_input_tokens",
"max_total_tokens",
"waiting_served_ratio",
"max_batch_total_tokens",
"max_waiting_tokens",
"validation_workers",
"max_client_batch_size",
"router",
@ -1601,18 +1596,6 @@
"example": "null",
"nullable": true
},
"max_batch_size": {
"type": "integer",
"example": "null",
"nullable": true,
"minimum": 0
},
"max_batch_total_tokens": {
"type": "integer",
"format": "int32",
"example": "32000",
"minimum": 0
},
"max_best_of": {
"type": "integer",
"example": "2",
@ -1644,19 +1627,6 @@
"example": "2048",
"minimum": 0
},
"max_waiting_tokens": {
"type": "integer",
"example": "20",
"minimum": 0
},
"model_device_type": {
"type": "string",
"example": "cuda"
},
"model_dtype": {
"type": "string",
"example": "torch.float16"
},
"model_id": {
"type": "string",
"description": "Model info",
@ -1690,11 +1660,6 @@
"version": {
"type": "string",
"example": "0.5.0"
},
"waiting_served_ratio": {
"type": "number",
"format": "float",
"example": "1.2"
}
}
},

View File

@ -7,25 +7,18 @@ edition.workspace = true
authors.workspace = true
homepage.workspace = true
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-router"
path = "src/main.rs"
[dependencies]
async-trait = "0.1.74"
async-stream = "0.3.5"
axum = { version = "0.7", features = ["json"] }
axum-tracing-opentelemetry = "0.16"
text-generation-client = { path = "client" }
clap = { version = "4.4.5", features = ["derive", "env"] }
futures = "0.3.28"
hf-hub = { workspace = true }
itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.23.0"
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
metrics = { workspace = true }
metrics-exporter-prometheus = { workspace = true }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
opentelemetry-otlp = "0.13.0"
@ -55,6 +48,7 @@ base64 = { workspace = true }
sysinfo = "0.30.13"
uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] }
csv = "1.3.0"
ureq = "=2.9"
[build-dependencies]

View File

@ -1 +0,0 @@
*

View File

@ -1 +0,0 @@
*

View File

@ -1,528 +1,85 @@
/// Batching and inference logic
use crate::infer::v3::queue::{Entry, Queue};
use crate::infer::{
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
use crate::infer::InferError;
use crate::{
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
};
use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token};
use nohash_hasher::IntMap;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
use text_generation_client::ClientError;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
pub(crate) struct SchedulerV3 {
/// Request queue
queue: Queue,
/// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>,
/// Raise a exception (custom function) used in the chat templates
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
impl SchedulerV3 {
#[allow(clippy::too_many_arguments)]
#[derive(Clone)]
pub(crate) struct ChatTemplate {
template: Template<'static, 'static>,
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
}
impl ChatTemplate {
pub(crate) fn new(
client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
template: String,
bos_token: Option<TokenizerConfigToken>,
eos_token: Option<TokenizerConfigToken>,
) -> Self {
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(
requires_padding,
block_size,
window_size,
speculate,
max_batch_total_tokens,
);
let batching_task_notifier = Arc::new(Notify::new());
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
queue.clone(),
batching_task_notifier.clone(),
generation_health,
));
// check if contains the tools variable within the template
let use_default_tool_template =
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap();
Self {
queue,
batching_task_notifier,
template,
bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template,
}
}
}
impl Scheduler for SchedulerV3 {
#[instrument(skip_all)]
fn schedule(
pub(crate) fn apply(
&self,
request: ValidGenerateRequest,
permit: OwnedSemaphorePermit,
) -> Result<GenerateStreamResponse, InferError> {
// MPSC channel to communicate with the background batching task
let (response_tx, response_rx) = mpsc::unbounded_channel();
let input_length = request.input_length;
// Append the request to the queue
self.queue.append(Entry {
request,
response_tx,
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
});
// Notify the background task that we have a new entry in the queue that needs
// to be batched
self.batching_task_notifier.notify_one();
// Return stream
Ok((
permit,
input_length,
UnboundedReceiverStream::new(response_rx),
))
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
/// Batches requests and sends them to the inference server
#[allow(clippy::too_many_arguments)]
pub(crate) async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
queue: Queue,
notifier: Arc<Notify>,
generation_health: Arc<AtomicBool>,
) {
// Infinite loop
loop {
// Wait for a notification from the Infer struct
notifier.notified().await;
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) = queue
.next_batch(
None,
max_batch_size,
max_batch_prefill_tokens,
max_batch_total_tokens,
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
.instrument(span)
.await;
let mut waiting_tokens = 1;
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
} else {
// Minimum batch size
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
// Add relationships
span.follows_from(&entry_waiting_span);
entry_waiting_span.follows_from(&span);
// Update entry
entry.temp_span = Some(entry_waiting_span);
mut messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
});
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
}
}
// Create span for this batch to add context to inference calls
let next_batch_size = entries.len();
let next_batch_span =
info_span!(parent: None, "batch", batch_size = next_batch_size);
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
next_batch_span.follows_from(&entry_batch_span);
entry_batch_span.follows_from(&next_batch_span);
// Update entry
entry.temp_span = Some(entry_batch_span);
});
cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
.instrument(next_batch_span)
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
#[instrument(skip_all)]
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
// Update health
generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
}
#[instrument(skip_all)]
async fn decode(
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
Err(err) => {
generation_health.store(false, Ordering::SeqCst);
for id in batch_ids {
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
}
/// Filter a `batch` and remove all requests not present in `entries`
#[instrument(skip_all)]
async fn filter_batch(
client: &mut ShardedClient,
next_batch: Option<CachedBatch>,
entries: &IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let mut batch = next_batch?;
// No need to filter
if batch.size as usize == entries.len() {
return Some(batch);
}
let id = batch.id;
// Retain only requests that are still in entries
batch.request_ids.retain(|id| entries.contains_key(id));
if batch.request_ids.is_empty() {
// All requests have been filtered out
// Next batch is now empty
// Clear it from the Python shards cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.clear_cache(Some(id)).await.unwrap();
None
} else {
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client.filter_batch(id, batch.request_ids).await.unwrap()
}
}
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
/// and filter entries
#[instrument(skip_all)]
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
generations.into_iter().for_each(|generation| {
let id = generation.request_id;
// Get entry
// We can `expect` here as the request id should always be in the entries
let entry = entries
.get(&id)
.expect("ID not found in entries. This is a bug.");
// Create and enter a span to link this function back to the entry
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
// Send generation responses back to the infer task
// If the receive an error from the Flume channel, it means that the client dropped the
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
});
}
/// Send responses through the `entry` response channel
fn send_responses(
generation: Generation,
entry: &Entry,
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}
let mut stopped = false;
if let Some(prefill_tokens) = generation.prefill_tokens {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
let prefill_tokens = prefill_tokens
.ids
.into_iter()
.zip(prefill_tokens.logprobs)
.zip(prefill_tokens.texts)
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
}
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
let token = Token {
id,
text,
logprob,
special,
};
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids
.iter()
.zip(top_tokens_.logprobs.iter())
.zip(top_tokens_.texts.iter())
.zip(top_tokens_.is_special.iter())
.map(|(((&id, &logprob), text), &special)| Token {
id,
text: text.to_string(),
logprob,
special,
})
.collect()
} else {
vec![]
};
match (&generation.generated_text, iterator.peek()) {
(Some(generated_text), None) => {
// Generation has ended
stopped = true;
// Send message
entry.response_tx.send(Ok(InferStreamResponse::End {
token,
top_tokens,
generated_text: GeneratedText::from(generated_text.clone()),
queued: entry.queue_time,
start: entry.batch_time.unwrap(),
}))?;
}
_ => {
// Send message
entry
.response_tx
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
}
}
}
Ok(stopped)
}
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
/// Send errors to Infer for all `entries`
#[instrument(skip_all)]
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
entries.drain().for_each(|(_, entry)| {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.
entry
.response_tx
.send(Err(err))
.unwrap_or(());
});
}
impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
fn from(value: text_generation_client::v3::GeneratedText) -> Self {
let v3_finish_reason =
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
let finish_reason = match v3_finish_reason {
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
};
Self {
text: value.text,
generated_tokens: value.generated_tokens,
finish_reason,
seed: value.seed,
}
self.template
.render(ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools: None,
tools_prompt: None,
})
.map_err(InferError::TemplateError)
}
}
// tests
#[cfg(test)]
mod tests {
use crate::infer::raise_exception;
use crate::infer::chat_template::raise_exception;
use crate::{ChatTemplateInputs, TextMessage};
use minijinja::Environment;

View File

@ -1,34 +0,0 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::Health;
#[derive(Clone)]
pub(crate) struct HealthCheck {
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
}
impl HealthCheck {
pub(crate) fn new(
client: Arc<dyn Health + Send + Sync>,
generation_health: Arc<AtomicBool>,
) -> Self {
Self {
client,
generation_health,
}
}
pub(crate) async fn check(&mut self) -> bool {
let value = if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards can allocate on device
self.client.device_health().await
} else {
self.client.model_health().await
}
.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value
}
}

View File

@ -1,23 +1,18 @@
mod health;
pub(crate) mod v2;
pub(crate) mod v3;
pub(crate) use health::HealthCheck;
// pub(crate) mod v2;
mod chat_template;
pub mod tool_grammar;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::GrammarType;
use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice,
};
use crate::{
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
Message, PrefillToken, Token,
};
use async_trait::async_trait;
use chat_template::ChatTemplate;
use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
use serde_json::{json, Map, Value};
use std::collections::HashMap;
use minijinja::ErrorKind;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
@ -26,12 +21,14 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::instrument;
pub(crate) trait Scheduler {
#[async_trait]
pub trait Backend {
fn schedule(
&self,
request: ValidGenerateRequest,
permit: OwnedSemaphorePermit,
) -> Result<GenerateStreamResponse, InferError>;
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
async fn health(&self, current_health: bool) -> bool;
}
/// Inference struct
@ -39,18 +36,20 @@ pub(crate) trait Scheduler {
pub struct Infer {
/// Validation
validation: Validation,
/// Request scheduler
scheduler: Arc<dyn Scheduler + Send + Sync>,
/// Request backend
backend: Arc<dyn Backend + Send + Sync>,
/// Chat template
chat_template: Option<ChatTemplate>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
/// Backend health
backend_health: Arc<AtomicBool>,
}
impl Infer {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
scheduler: Arc<dyn Scheduler + Send + Sync>,
backend: impl Backend + Send + Sync + 'static,
validation: Validation,
max_concurrent_requests: usize,
tokenizer_config: HubTokenizerConfig,
@ -71,18 +70,22 @@ impl Infer {
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
// Backend health
let backend_health = Arc::new(AtomicBool::new(false));
Self {
validation,
scheduler,
backend: Arc::new(backend),
chat_template,
limit_concurrent_requests: semaphore,
backend_health,
}
}
/// Add a new request to the queue and return a stream of InferStreamResponse
#[instrument(skip_all)]
pub(crate) async fn generate_stream(
&self,
pub(crate) async fn generate_stream<'a>(
&'a self,
request: GenerateRequest,
) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
@ -103,7 +106,10 @@ impl Infer {
err
})?;
self.scheduler.schedule(valid_request, permit)
let input_length = valid_request.input_length;
let generation_stream = self.backend.schedule(valid_request)?;
Ok((permit, input_length, generation_stream))
}
/// Tokenizer the input
@ -155,7 +161,7 @@ impl Infer {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream and keep semaphore permit as long as generate lives
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
let (_permit, _input_length, stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
@ -165,6 +171,8 @@ impl Infer {
let mut result_start = None;
let mut result_queued = None;
let mut stream = Box::pin(stream);
// Iterate on stream
while let Some(response) = stream.next().await {
match response? {
@ -256,207 +264,15 @@ impl Infer {
let best_response = infer_responses.remove(max_index);
Ok((best_response, infer_responses))
}
}
/// Raise a exception (custom function) used in the chat templates
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
#[derive(Clone)]
struct ChatTemplate {
template: Template<'static, 'static>,
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
}
impl ChatTemplate {
fn new(
template: String,
bos_token: Option<TokenizerConfigToken>,
eos_token: Option<TokenizerConfigToken>,
) -> Self {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
// check if contains the tools variable within the template
let use_default_tool_template =
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap();
Self {
template,
bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template,
}
}
fn apply(
&self,
mut messages: Vec<Message>,
grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content.push(MessageChunk::Text {
text: format!("\n---\n{}\n{}", tool_prompt, tools),
});
}
}
}
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template
.render(ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools: None,
tools_prompt: None,
})
.map_err(InferError::TemplateError)
}
}
pub struct ToolGrammar {}
impl ToolGrammar {
// find a tool by name
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
tools
.iter()
.find(|tool| tool.function.name == name)
.cloned()
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
}
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> {
// if no tools are provided, we return None
let tools = match tools {
Some(tools) if !tools.is_empty() => tools,
_ => return Ok(None),
};
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![Self::find_tool_by_name(&tools, &name)?]
}
ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
Ok(Some(tools))
#[instrument(skip(self))]
pub(crate) async fn health(&self) -> bool {
let health = self
.backend
.health(self.backend_health.load(Ordering::SeqCst))
.await;
self.backend_health.store(health, Ordering::SeqCst);
health
}
}
@ -468,15 +284,15 @@ pub(crate) type GenerateStreamResponse = (
);
#[derive(Debug)]
pub(crate) struct GeneratedText {
pub(crate) text: String,
pub(crate) generated_tokens: u32,
pub(crate) finish_reason: FinishReason,
pub(crate) seed: Option<u64>,
pub struct GeneratedText {
pub text: String,
pub generated_tokens: u32,
pub finish_reason: FinishReason,
pub seed: Option<u64>,
}
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
pub enum InferStreamResponse {
// Optional first message
Prefill(Vec<PrefillToken>),
// Intermediate messages

View File

@ -0,0 +1,135 @@
use crate::infer::InferError;
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
pub(crate) struct ToolGrammar {}
impl ToolGrammar {
// find a tool by name
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
tools
.iter()
.find(|tool| tool.function.name == name)
.cloned()
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
}
pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> {
// if no tools are provided, we return None
let tools = match tools {
Some(tools) if !tools.is_empty() => tools,
_ => return Ok(None),
};
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![Self::find_tool_by_name(&tools, &name)?]
}
ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
};
// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);
let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();
// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};
// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);
// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();
// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);
// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();
// Add 'name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();
let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};
Ok(Some(tools))
}
}

View File

@ -1,4 +1,4 @@
mod queue;
mod scheduler;
pub(crate) use scheduler::SchedulerV2;
pub(crate) use scheduler::BackendV2;

View File

@ -1,7 +1,7 @@
/// Batching and inference logic
use crate::infer::v2::queue::{Entry, Queue};
use crate::infer::{
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
};
use crate::validation::ValidGenerateRequest;
use crate::{FinishReason, PrefillToken, Token};
@ -18,14 +18,14 @@ use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{info_span, instrument, Instrument, Span};
pub(crate) struct SchedulerV2 {
pub(crate) struct BackendV2 {
/// Request queue
queue: Queue,
/// Notify batcher on queue appends
batching_task_notifier: Arc<Notify>,
}
impl SchedulerV2 {
impl BackendV2 {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
client: ShardedClient,
@ -69,7 +69,7 @@ impl SchedulerV2 {
}
}
impl Scheduler for SchedulerV2 {
impl Backend for BackendV2 {
#[instrument(skip_all)]
fn schedule(
&self,

View File

@ -1,5 +0,0 @@
mod block_allocator;
mod queue;
mod scheduler;
pub(crate) use scheduler::SchedulerV3;

View File

@ -1,11 +1,12 @@
/// Text Generation Inference Webserver
pub mod config;
mod infer;
pub mod infer;
pub mod server;
mod validation;
pub mod validation;
#[cfg(feature = "kserve")]
mod kserve;
pub mod logging;
pub mod usage_stats;
@ -148,12 +149,13 @@ pub struct Info {
pub model_id: String,
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
pub model_sha: Option<String>,
#[schema(example = "torch.float16")]
pub model_dtype: String,
#[schema(example = "cuda")]
pub model_device_type: String,
// #[schema(example = "torch.float16")]
// pub model_dtype: String,
// #[schema(example = "cuda")]
// pub model_device_type: String,
#[schema(nullable = true, example = "text-generation")]
pub model_pipeline_tag: Option<String>,
/// Router Parameters
#[schema(example = "128")]
pub max_concurrent_requests: usize,
@ -165,18 +167,11 @@ pub struct Info {
pub max_input_tokens: usize,
#[schema(example = "2048")]
pub max_total_tokens: usize,
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
#[schema(example = "20")]
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
#[schema(example = "2")]
pub validation_workers: usize,
#[schema(example = "32")]
pub max_client_batch_size: usize,
/// Router Info
#[schema(example = "text-generation-router")]
pub router: &'static str,
@ -1068,23 +1063,23 @@ impl From<CompatGenerateRequest> for GenerateRequest {
#[derive(Debug, Serialize, ToSchema)]
pub struct PrefillToken {
#[schema(example = 0)]
id: u32,
pub id: u32,
#[schema(example = "test")]
text: String,
pub text: String,
#[schema(nullable = true, example = - 0.34)]
logprob: f32,
pub logprob: f32,
}
#[derive(Debug, Serialize, ToSchema, Clone)]
pub struct Token {
#[schema(example = 0)]
id: u32,
pub id: u32,
#[schema(example = "test")]
text: String,
pub text: String,
#[schema(nullable = true, example = - 0.34)]
logprob: f32,
pub logprob: f32,
#[schema(example = "false")]
special: bool,
pub special: bool,
}
#[derive(Debug, Serialize, ToSchema)]
@ -1102,7 +1097,7 @@ pub struct SimpleToken {
#[derive(Debug, Serialize, ToSchema)]
#[serde(rename_all(serialize = "snake_case"))]
#[schema(example = "Length")]
pub(crate) enum FinishReason {
pub enum FinishReason {
#[schema(rename = "length")]
Length,
#[serde(rename = "eos_token")]

81
router/src/logging.rs Normal file
View File

@ -0,0 +1,81 @@
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
pub fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_ansi(ansi)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer.json().flatten_event(true).boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
otlp_service_name,
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
init_tracing_opentelemetry::init_propagator().unwrap();
};
}
// Filter events with LOG_LEVEL
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs to be spammed with tokio level informations
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}

View File

@ -1,14 +1,13 @@
/// HTTP Server logic
use crate::config::Config;
use crate::infer::v2::SchedulerV2;
use crate::infer::v3::SchedulerV3;
use crate::infer::{HealthCheck, Scheduler};
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
use crate::infer::tool_grammar::ToolGrammar;
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
#[cfg(feature = "kserve")]
use crate::kserve::{
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::usage_stats;
use crate::validation::ValidationError;
use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters,
@ -27,7 +26,7 @@ use crate::{
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
@ -37,15 +36,18 @@ use futures::stream::StreamExt;
use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::Stream;
use futures::TryStreamExt;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use http::header::AUTHORIZATION;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use text_generation_client::{v2, v3, ClientError, ShardInfo};
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use thiserror::Error;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::Tokenizer;
use tokio::select;
use tokio::signal;
@ -124,12 +126,10 @@ responses(
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
)
)]
#[instrument(skip(health))]
#[instrument(skip(infer))]
/// Health check method
async fn health(
mut health: Extension<HealthCheck>,
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
match health.check().await {
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
match infer.health().await {
true => Ok(()),
false => Err((
StatusCode::SERVICE_UNAVAILABLE,
@ -430,8 +430,9 @@ async fn generate_stream_internal(
} else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// Keep permit as long as generate_stream lives
Ok((_permit, _input_length, mut response_stream)) => {
Ok((_permit, _input_length, response_stream)) => {
let mut index = 0;
let mut response_stream = Box::pin(response_stream);
// Server-Sent Event stream
while let Some(response) = response_stream.next().await {
index += 1;
@ -1396,262 +1397,456 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
#[derive(Clone, Debug)]
pub(crate) struct ComputeType(String);
// OpenAPI documentation
#[derive(OpenApi)]
#[openapi(
paths(
health,
get_model_info,
compat_generate,
generate,
generate_stream,
chat_completions,
completions,
tokenize,
metrics,
),
components(
schemas(
Info,
CompatGenerateRequest,
GenerateRequest,
GrammarType,
ChatRequest,
Message,
MessageContent,
MessageChunk,
Url,
FunctionName,
OutputMessage,
TextMessage,
ToolCallMessage,
ToolCallDelta,
ChatCompletionComplete,
ChatCompletionChoice,
ChatCompletionDelta,
ChatCompletionChunk,
ChatCompletionLogprob,
ChatCompletionLogprobs,
ChatCompletionTopLogprob,
ChatCompletion,
CompletionRequest,
CompletionComplete,
Chunk,
Completion,
CompletionFinal,
Prompt,
GenerateParameters,
PrefillToken,
Token,
GenerateResponse,
TokenizeResponse,
SimpleToken,
BestOfSequence,
Details,
FinishReason,
StreamResponse,
StreamDetails,
ErrorResponse,
GrammarType,
Usage,
DeltaToolCall,
ToolType,
Tool,
ToolCall,
Function,
FunctionDefinition,
ToolChoice,
)
),
tags(
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
),
info(
title = "Text Generation Inference",
license(
name = "Apache 2.0",
url = "https://www.apache.org/licenses/LICENSE-2.0"
)
)
)]
pub struct ApiDoc;
pub fn schema() -> ApiDoc {
ApiDoc
}
/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
master_shard_uds_path: String,
model_info: HubModelInfo,
compat_return_full_text: bool,
backend: impl Backend + Send + Sync + 'static,
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_tokens: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
tokenizer: Option<Tokenizer>,
config: Option<Config>,
validation_workers: usize,
addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
api_key: Option<String>,
tokenizer_name: String,
tokenizer_config_path: Option<String>,
revision: Option<String>,
hostname: String,
port: u16,
cors_allow_origin: Option<Vec<String>>,
ngrok: bool,
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig,
preprocessor_config: Option<HubPreprocessorConfig>,
processor_config: HubProcessorConfig,
messages_api_enabled: bool,
grammar_support: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
print_schema_command: bool,
disable_usage_stats: bool,
disable_crash_reports: bool,
) -> Result<(), WebServerError> {
// OpenAPI documentation
#[derive(OpenApi)]
#[openapi(
paths(
health,
get_model_info,
compat_generate,
generate,
generate_stream,
chat_completions,
completions,
tokenize,
metrics,
),
components(
schemas(
Info,
CompatGenerateRequest,
GenerateRequest,
GrammarType,
ChatRequest,
Message,
MessageContent,
MessageChunk,
Url,
FunctionName,
OutputMessage,
TextMessage,
ToolCallMessage,
ToolCallDelta,
ChatCompletionComplete,
ChatCompletionChoice,
ChatCompletionDelta,
ChatCompletionChunk,
ChatCompletionLogprob,
ChatCompletionLogprobs,
ChatCompletionTopLogprob,
ChatCompletion,
CompletionRequest,
CompletionComplete,
Chunk,
Completion,
CompletionFinal,
Prompt,
GenerateParameters,
PrefillToken,
Token,
GenerateResponse,
TokenizeResponse,
SimpleToken,
BestOfSequence,
Details,
FinishReason,
StreamResponse,
StreamDetails,
ErrorResponse,
GrammarType,
Usage,
DeltaToolCall,
ToolType,
Tool,
ToolCall,
Function,
FunctionDefinition,
ToolChoice,
)
),
tags(
(name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
),
info(
title = "Text Generation Inference",
license(
name = "Apache 2.0",
url = "https://www.apache.org/licenses/LICENSE-2.0"
)
)
)]
struct ApiDoc;
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
let allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
AllowOrigin::list(
cors_allow_origin
.iter()
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
)
});
// Create state
if print_schema_command {
let api_doc = ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
// This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
// Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
Arc<dyn Scheduler + Send + Sync>,
HealthCheck,
ShardInfo,
u32,
) = {
// Helper function to check both v2 and v3
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
);
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
// Warn if user added his own max-batch-total-tokens as we will ignore it
if max_batch_total_tokens.is_some() {
tracing::warn!(
"`--max-batch-total-tokens` is deprecated for Flash \
Attention models."
);
tracing::warn!(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(WebServerError::NotEnoughMemory(max_total_tokens));
}
Ok(max_supported_batch_total_tokens)
}
}
};
let generation_health = Arc::new(AtomicBool::new(false));
match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await {
Ok(mut sharded_client) => {
// server is running on v3
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(WebServerError::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(WebServerError::Warmup)?,
)?;
let health_ext =
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
let scheduler = Arc::new(SchedulerV3::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
));
tracing::info!("Using scheduler V3");
(scheduler, health_ext, shard_info, max_batch_total_tokens)
}
Err(_) => {
let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path)
.await
.map_err(WebServerError::Connection)?;
// server is running on v2
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.map_err(WebServerError::Cache)?;
// Get info from the shard
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(WebServerError::Warmup)?,
)?;
let health_ext =
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
let scheduler = Arc::new(SchedulerV2::new(
sharded_client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
generation_health,
));
tracing::info!("Using scheduler V2");
(scheduler, health_ext, shard_info, max_batch_total_tokens)
}
}
} else {
Type::None
};
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
// Load tokenizer and model info
let (
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
}
}
}
tokenizer
});
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
});
let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();
let preprocessor_config: Option<HubPreprocessorConfig> =
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation is disabled");
}
// Only send usage stats when TGI is run in container and the function returns Some
let is_container = matches!(usage_stats::is_container(), Ok(true));
let user_agent = if !disable_usage_stats && is_container {
let reduced_args = usage_stats::Args::new(
config.clone(),
tokenizer_config.tokenizer_class.clone(),
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
// waiting_served_ratio,
// max_batch_prefill_tokens,
// max_batch_total_tokens,
// max_waiting_tokens,
// max_batch_size,
revision.clone(),
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
);
Some(usage_stats::UserAgent::new(reduced_args))
} else {
None
};
if let Some(ref ua) = user_agent {
let start_event =
usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
tokio::spawn(async move {
start_event.send().await;
});
};
let compat_return_full_text = match &model_info.pipeline_tag {
None => {
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
true
}
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
};
let result = start(
backend,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
validation_workers,
api_key,
config,
(tokenizer, tokenizer_config),
(preprocessor_config, processor_config),
hostname,
port,
ngrok,
_ngrok_authtoken,
_ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
model_info,
compat_return_full_text,
allow_origin,
)
.await;
if let Some(ua) = user_agent {
match result {
Ok(_) => {
let stop_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Stop,
None,
);
stop_event.send().await;
Ok(())
}
Err(e) => {
if !disable_crash_reports {
let error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some(e.to_string()),
);
error_event.send().await;
} else {
let unknow_error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some("unknow_error".to_string()),
);
unknow_error_event.send().await;
}
Err(e)
}
}
} else {
result
}
}
#[allow(clippy::too_many_arguments)]
async fn start(
backend: impl Backend + Send + Sync + 'static,
max_concurrent_requests: usize,
max_best_of: usize,
max_stop_sequences: usize,
max_top_n_tokens: u32,
max_input_tokens: usize,
max_total_tokens: usize,
validation_workers: usize,
api_key: Option<String>,
config: Option<Config>,
(tokenizer, tokenizer_config): (Option<Tokenizer>, HubTokenizerConfig),
(preprocessor_config, processor_config): (Option<HubPreprocessorConfig>, HubProcessorConfig),
hostname: String,
port: u16,
ngrok: bool,
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
model_info: HubModelInfo,
compat_return_full_text: bool,
allow_origin: Option<AllowOrigin>,
) -> Result<(), WebServerError> {
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
std::env::var("AIP_HTTP_PORT")
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
.unwrap_or(port)
} else {
port
};
let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => {
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
}
};
// Create state
let validation = Validation::new(
validation_workers,
tokenizer,
@ -1662,11 +1857,11 @@ pub async fn run(
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
grammar_support,
disable_grammar_support,
);
let infer = Infer::new(
scheduler,
backend,
validation,
max_concurrent_requests,
tokenizer_config,
@ -1703,8 +1898,8 @@ pub async fn run(
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
// Speculated tokens buckets
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
// let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
// let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
// Prometheus handler
let builder = PrometheusBuilder::new()
@ -1717,9 +1912,9 @@ pub async fn run(
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
.unwrap()
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
.unwrap()
.set_buckets_for_metric(skipped_matcher, &skipped_buckets)
.unwrap();
// .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
// .unwrap();
let prom_handle = builder
.install_recorder()
.expect("failed to install metrics recorder");
@ -1735,18 +1930,18 @@ pub async fn run(
let info = Info {
model_id: model_info.model_id,
model_sha: model_info.sha,
model_dtype: shard_info.dtype,
model_device_type: shard_info.device_type,
// model_dtype: shard_info.dtype,
// model_device_type: shard_info.device_type,
model_pipeline_tag: model_info.pipeline_tag,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
// waiting_served_ratio,
// max_batch_total_tokens,
// max_waiting_tokens,
// max_batch_size,
validation_workers,
max_client_batch_size,
router: env!("CARGO_PKG_NAME"),
@ -1907,7 +2102,6 @@ pub async fn run(
// add layers after routes
app = app
.layer(Extension(info))
.layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
.layer(Extension(compute_type))
@ -1945,6 +2139,68 @@ pub async fn run(
Ok(())
}
/// get model info from the Huggingface Hub
pub async fn get_hub_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
let response = api.info_request().send().await.ok()?;
if response.status().is_success() {
let hub_model_info: HubModelInfo =
serde_json::from_str(&response.text().await.ok()?).ok()?;
if let Some(sha) = &hub_model_info.sha {
tracing::info!(
"Serving revision {sha} of model {}",
hub_model_info.model_id
);
}
Some(hub_model_info)
} else {
None
}
}
/// get base tokenizer
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of `User`.
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
let api_base_repo = api.repo(Repo::with_revision(
base_model_id.to_string(),
RepoType::Model,
"main".to_string(),
));
api_base_repo.get("tokenizer.json").await.ok()
} else {
None
}
}
/// get tokenizer_config from the Huggingface Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(tokenizer_config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
.map_err(|e| {
tracing::warn!("Unable to parse tokenizer config: {}", e);
e
})
.ok()?;
Some(tokenizer_config)
}
/// Shutdown signal handler
async fn shutdown_signal() {
let ctrl_c = async {
@ -2008,16 +2264,77 @@ impl From<InferError> for Event {
#[derive(Debug, Error)]
pub enum WebServerError {
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
#[error("Axum error: {0}")]
Axum(#[from] axum::BoxError),
}
/// Create a post_processor for the LlamaTokenizer
fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos.as_str())
.expect("Should have found the bos token id");
special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos.as_str()));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos.as_str())
.expect("Should have found the eos token id");
special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos.as_str()));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos.as_str()));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos.as_str()));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}

View File

@ -78,11 +78,11 @@ pub struct Args {
max_top_n_tokens: u32,
max_input_tokens: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
// waiting_served_ratio: f32,
// max_batch_prefill_tokens: u32,
// max_batch_total_tokens: Option<u32>,
// max_waiting_tokens: usize,
// max_batch_size: Option<usize>,
revision: Option<String>,
validation_workers: usize,
messages_api_enabled: bool,
@ -103,11 +103,11 @@ impl Args {
max_top_n_tokens: u32,
max_input_tokens: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: Option<u32>,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
// waiting_served_ratio: f32,
// max_batch_prefill_tokens: u32,
// max_batch_total_tokens: Option<u32>,
// max_waiting_tokens: usize,
// max_batch_size: Option<usize>,
revision: Option<String>,
validation_workers: usize,
messages_api_enabled: bool,
@ -125,11 +125,11 @@ impl Args {
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
// waiting_served_ratio,
// max_batch_prefill_tokens,
// max_batch_total_tokens,
// max_waiting_tokens,
// max_batch_size,
revision,
validation_workers,
messages_api_enabled,

View File

@ -5,13 +5,12 @@ use crate::{
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
};
use base64::{engine::general_purpose::STANDARD, Engine};
use image::{io::Reader as ImageReader, ImageFormat};
use image::{ImageFormat, ImageReader};
use jsonschema::{Draft, JSONSchema};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::io::Cursor;
use std::iter;
use text_generation_client::{Chunk, Image, InputChunk};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
use tokio::sync::mpsc;
@ -96,7 +95,7 @@ impl Validation {
&self,
inputs: String,
truncate: Option<usize>,
) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
// If we have a fast tokenizer
if let Some(sender) = &self.sender {
// Create response channel
@ -122,7 +121,7 @@ impl Validation {
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel
@ -181,11 +180,7 @@ impl Validation {
input_length = input_length.saturating_sub(max_new_tokens as usize);
}
Ok((
vec![Chunk::Text(inputs).into()],
input_length,
max_new_tokens,
))
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
}
}
@ -589,7 +584,7 @@ fn prepare_input(
tokenizer: &Tokenizer,
config: Option<&Config>,
preprocessor_config: Option<&HubPreprocessorConfig>,
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
@ -601,16 +596,16 @@ fn prepare_input(
let chunk_start = chunk.start();
let chunk_end = chunk.end();
if chunk_start != start {
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
tokenizer_query.push_str(&inputs[start..chunk_start]);
}
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
input_chunks.push(Chunk::Image(Image { data, mimetype }));
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
start = chunk_end;
}
if start != inputs.len() {
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
input_chunks.push(Chunk::Text(inputs[start..].to_string()));
tokenizer_query.push_str(&inputs[start..]);
}
@ -618,7 +613,7 @@ fn prepare_input(
(tokenizer_query, input_chunks)
}
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
_ => (inputs.clone(), vec![Chunk::Text(inputs)]),
};
// Get the number of tokens in the input
@ -631,18 +626,51 @@ fn prepare_input(
type TokenizerRequest = (
(String, Option<usize>),
oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
Span,
);
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Image {
pub data: Vec<u8>,
pub mimetype: String,
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Chunk {
Text(String),
Image(Image),
}
/// Convert input chunks to a stringly-typed input for backwards
/// compat for backends that haven't implemented chunked inputs.
pub trait ChunksToString {
/// Convert chunks to string.
fn chunks_to_string(&self) -> String;
}
impl ChunksToString for Vec<Chunk> {
fn chunks_to_string(&self) -> String {
let mut output = String::new();
self.iter().for_each(|c| match &c {
Chunk::Text(text) => output.push_str(text),
Chunk::Image(Image { data, mimetype }) => {
let encoded = STANDARD.encode(data);
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
}
});
output
}
}
#[derive(Debug, Clone)]
pub(crate) enum ValidGrammar {
pub enum ValidGrammar {
Json(String),
Regex(String),
}
#[derive(Debug, Clone)]
pub(crate) struct ValidParameters {
pub struct ValidParameters {
/// / exponential scaling output probability distribution
pub temperature: f32,
/// / restricting to the k highest probability elements
@ -666,7 +694,7 @@ pub(crate) struct ValidParameters {
}
#[derive(Debug, Clone)]
pub(crate) struct ValidStoppingParameters {
pub struct ValidStoppingParameters {
/// / Maximum number of generated tokens
pub max_new_tokens: u32,
/// / Optional stopping sequences
@ -677,8 +705,8 @@ pub(crate) struct ValidStoppingParameters {
}
#[derive(Debug, Clone)]
pub(crate) struct ValidGenerateRequest {
pub inputs: Vec<InputChunk>,
pub struct ValidGenerateRequest {
pub inputs: Vec<Chunk>,
pub input_length: u32,
pub truncate: u32,
pub decoder_input_details: bool,
@ -750,6 +778,8 @@ pub enum ValidationError {
InvalidImageContent(String),
#[error("Could not fetch image: {0}")]
FailedFetchImage(#[from] reqwest::Error),
#[error("{0} modality is not supported")]
UnsupportedModality(&'static str),
}
#[cfg(test)]

View File

@ -167,22 +167,24 @@ def check_openapi(check: bool):
else:
os.rename(tmp_filename, filename)
print("OpenAPI documentation updated.")
errors = subprocess.run(
p = subprocess.run(
[
"swagger-cli",
"redocly",
# allow for trailing whitespace since it's not significant
# and the precommit hook will remove it
"validate",
"lint",
filename,
],
capture_output=True,
).stderr.decode("utf-8")
)
errors = p.stderr.decode("utf-8")
# The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where
# utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969
if not errors.startswith("Swagger schema validation failed."):
print(errors)
if p.returncode != 0:
print(errors)
raise Exception(
f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}"
f"OpenAPI documentation is invalid, `redocly lint {filename}` showed some error:\n {errors}"
)
return True