[TENSORRT-LLM] - Implement new looper thread based backend (#2357)
* (backend) use parking_lot crate for RwLock fairness
# Conflicts:
# backends/trtllm/src/backend.rs
* (launcher) default new server::run parameters to false for now
* (chore) fmt ... why?
* (ffi) use const for GetSamplingConfig
* (server) expose new SchedulingError
* (trt)
* (build) setup ccache if available
* (ffi) add max_new_tokens parameters
* (backend) cleanup a bit
* (backend) expose PullNewTokens
* (ffi) cleanup again
* (ffi) add missing headers imports
* (ffi) add template specialization to catch and convert to Rust Result<T, tensorrt_llm::common::TllmException>
* (looper) new looper initial implementation
* (ffi) remove narrowing type warning
* (ffi) encode the provided user prompt within each request thread
* (misc) change scope identifiers
* (backend) implement the post_processor background thread
* (misc) missing Result types for Rust
* use blocking_recv in looper to consume awaiting_requests at max before pulling in a single step
* (server) forward auth_token to server::run
* (build) fetchcontent use archives instead of git
* (ffi) fix usage of wrong vector constructor making a capacity fill call
* (ffi) missing namespace for tle::Response
* (ffi) do not use reference capture in lambda as we are not capturing anything
* (backend) refactor & cleanup
* (Dockerfile.trtllm) delete for now
* (misc) simplify [make_]move_iterator by using c++20 type inference
* (misc) no need to move for uint32_t items
* (scheduler) rework submit/pull logic
* (post) impl postprocessing
* (misc) delete backend.rs
* (misc) rerun-if-changed all the cmake modules
* (misc) move to latest trtllm
* (fix): HOPPER_SM_MAJOR is 9 not 8
* (misc: build for sm_{75,80,86,89,90} by default
* (misc): build with trtllm 0.13.0
* (misc): increase verbosity of spdlog
* (fix): do not recreate the stateful hashmap at every it
* (misc): update dependency in trtllm dockerfile
* (misc): update dependency in trtllm dockerfile
* (misc): disable logging in release mode
* (misc): improve trtllm download script robustness
* (fix): ore fixes for Dockerfile
* misc(cuda): require 12.6
* chore(cmake): use correct policy for download_timestamp
* feat(looper): check engine and executorWorker paths exist before creating the backend
* chore(cmake): download timestamp should be before URL
* feat(looper): minor optimizations to avoid growing too much the containers
* chore(trtllm): move dockerfile to right place
* chore(trtllm): disable tokenizer parallelism by default
* chore(trtllm): fmt
* chore(trtllm): post-rebase commit
* chore(trtllm): remove unused method
* feat(trtllm): cache maxNumTokens to avoid calling JSON everytime
* misc(router): remove SchedulingError
* feat(trtllm): do not tokenize twice
* Revert "chore(trtllm): remove unused method"
This reverts commit 31747163
* chore(rebase): fix invalid references
* chore(router): add python dependency
* Lint.
* Fix bad rebase
---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
ed87b464b4
commit
43df056eee
|
@ -2706,9 +2706,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "opentelemetry"
|
name = "opentelemetry"
|
||||||
version = "0.23.0"
|
version = "0.24.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1b69a91d4893e713e06f724597ad630f1fa76057a5e1026c0ca67054a9032a76"
|
checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
|
@ -2819,19 +2819,17 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "opentelemetry_sdk"
|
name = "opentelemetry_sdk"
|
||||||
version = "0.23.0"
|
version = "0.24.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ae312d58eaa90a82d2e627fd86e075cf5230b3f11794e2ed74199ebbe572d4fd"
|
checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"futures-channel",
|
"futures-channel",
|
||||||
"futures-executor",
|
"futures-executor",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"glob",
|
"glob",
|
||||||
"lazy_static",
|
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"opentelemetry 0.23.0",
|
"opentelemetry 0.24.0",
|
||||||
"ordered-float 4.3.0",
|
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"rand",
|
"rand",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
|
@ -4185,16 +4183,17 @@ dependencies = [
|
||||||
"cmake",
|
"cmake",
|
||||||
"cxx",
|
"cxx",
|
||||||
"cxx-build",
|
"cxx-build",
|
||||||
|
"hashbrown 0.14.5",
|
||||||
|
"hf-hub",
|
||||||
"log",
|
"log",
|
||||||
"parking_lot",
|
|
||||||
"pkg-config",
|
"pkg-config",
|
||||||
"text-generation-router",
|
"text-generation-router",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers 0.19.1",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-opentelemetry 0.24.0",
|
"tracing-opentelemetry 0.25.0",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -4212,7 +4211,7 @@ dependencies = [
|
||||||
"tabled",
|
"tabled",
|
||||||
"text-generation-client",
|
"text-generation-client",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers 0.20.0",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
|
@ -4292,7 +4291,7 @@ dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sysinfo",
|
"sysinfo",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers 0.20.0",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
|
@ -4341,7 +4340,7 @@ dependencies = [
|
||||||
"slotmap",
|
"slotmap",
|
||||||
"text-generation-router",
|
"text-generation-router",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers 0.20.0",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tonic 0.10.2",
|
"tonic 0.10.2",
|
||||||
|
@ -4392,7 +4391,7 @@ dependencies = [
|
||||||
"slotmap",
|
"slotmap",
|
||||||
"text-generation-router",
|
"text-generation-router",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokenizers 0.20.0",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tonic 0.10.2",
|
"tonic 0.10.2",
|
||||||
|
@ -4514,39 +4513,6 @@ version = "0.1.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tokenizers"
|
|
||||||
version = "0.19.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd"
|
|
||||||
dependencies = [
|
|
||||||
"aho-corasick",
|
|
||||||
"derive_builder",
|
|
||||||
"esaxx-rs",
|
|
||||||
"getrandom",
|
|
||||||
"hf-hub",
|
|
||||||
"indicatif",
|
|
||||||
"itertools 0.12.1",
|
|
||||||
"lazy_static",
|
|
||||||
"log",
|
|
||||||
"macro_rules_attribute",
|
|
||||||
"monostate",
|
|
||||||
"onig",
|
|
||||||
"paste",
|
|
||||||
"rand",
|
|
||||||
"rayon",
|
|
||||||
"rayon-cond",
|
|
||||||
"regex",
|
|
||||||
"regex-syntax 0.8.5",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
"spm_precompiled",
|
|
||||||
"thiserror",
|
|
||||||
"unicode-normalization-alignments",
|
|
||||||
"unicode-segmentation",
|
|
||||||
"unicode_categories",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokenizers"
|
name = "tokenizers"
|
||||||
version = "0.20.0"
|
version = "0.20.0"
|
||||||
|
@ -4933,14 +4899,14 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tracing-opentelemetry"
|
name = "tracing-opentelemetry"
|
||||||
version = "0.24.0"
|
version = "0.25.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4"
|
checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"js-sys",
|
"js-sys",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"opentelemetry 0.23.0",
|
"opentelemetry 0.24.0",
|
||||||
"opentelemetry_sdk 0.23.0",
|
"opentelemetry_sdk 0.24.1",
|
||||||
"smallvec",
|
"smallvec",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-core",
|
"tracing-core",
|
||||||
|
|
|
@ -1,23 +0,0 @@
|
||||||
# All the tooling for CUDA
|
|
||||||
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
|
||||||
|
|
||||||
WORKDIR /usr/src/tgi/backends/trtllm
|
|
||||||
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget
|
|
||||||
|
|
||||||
COPY . /usr/src/tgi
|
|
||||||
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
|
|
||||||
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
|
|
||||||
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl
|
|
||||||
|
|
||||||
# All the tooling for Rust
|
|
||||||
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
|
|
||||||
WORKDIR /usr/src
|
|
||||||
|
|
||||||
# Include CUDA related libraries and tools to the Rust based image
|
|
||||||
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
|
|
||||||
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
|
|
||||||
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
|
|
||||||
ENV PATH=/usr/local/cuda/bin:$PATH
|
|
||||||
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH
|
|
||||||
|
|
||||||
RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3
|
|
|
@ -10,7 +10,7 @@ COPY . .
|
||||||
RUN cargo chef prepare --recipe-path recipe.json
|
RUN cargo chef prepare --recipe-path recipe.json
|
||||||
|
|
||||||
# CUDA dependent dependencies resolver stage
|
# CUDA dependent dependencies resolver stage
|
||||||
FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
FROM nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||||
|
@ -26,6 +26,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||||
ninja-build \
|
ninja-build \
|
||||||
pkg-config \
|
pkg-config \
|
||||||
python3 \
|
python3 \
|
||||||
|
python3-dev \
|
||||||
python3-setuptools \
|
python3-setuptools \
|
||||||
tar \
|
tar \
|
||||||
wget
|
wget
|
||||||
|
@ -82,10 +83,15 @@ RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$
|
||||||
cd backends/trtllm && \
|
cd backends/trtllm && \
|
||||||
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
|
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
|
||||||
|
|
||||||
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
|
FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu22.04 AS runtime
|
||||||
|
RUN apt update && apt install -y python3 && \
|
||||||
|
rm -rf /var/lib/{apt,dpkg,cache,log}/
|
||||||
|
|
||||||
WORKDIR /usr/local/tgi/bin
|
WORKDIR /usr/local/tgi/bin
|
||||||
|
|
||||||
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
|
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
|
||||||
|
ENV TOKENIZERS_PARALLELISM=false
|
||||||
|
ENV OMPI_MCA_plm_rsh_agent=""
|
||||||
|
|
||||||
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
||||||
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
|
@ -1,5 +1,17 @@
|
||||||
cmake_minimum_required(VERSION 3.20)
|
cmake_minimum_required(VERSION 3.20)
|
||||||
|
|
||||||
|
if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||||
|
find_program(CCACHE_EXECUTABLE "ccache")
|
||||||
|
if (CCACHE_EXECUTABLE)
|
||||||
|
message(STATUS "Using ccache")
|
||||||
|
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
|
||||||
|
endif ()
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
|
||||||
|
cmake_policy(SET CMP0135 NEW)
|
||||||
|
endif ()
|
||||||
|
|
||||||
project(tgi-trtllm-backend VERSION 1.0.0)
|
project(tgi-trtllm-backend VERSION 1.0.0)
|
||||||
set(CMAKE_CXX_STANDARD 20)
|
set(CMAKE_CXX_STANDARD 20)
|
||||||
|
|
||||||
|
@ -14,7 +26,7 @@ set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include"
|
||||||
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
|
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
|
||||||
|
|
||||||
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
|
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
|
||||||
find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
|
find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
|
||||||
|
|
||||||
#### External dependencies ####
|
#### External dependencies ####
|
||||||
include(cmake/fmt.cmake)
|
include(cmake/fmt.cmake)
|
||||||
|
|
|
@ -10,16 +10,17 @@ async-trait = "0.1"
|
||||||
async-stream = "0.3"
|
async-stream = "0.3"
|
||||||
clap = { version = "4.5", features = ["derive"] }
|
clap = { version = "4.5", features = ["derive"] }
|
||||||
cxx = "1.0"
|
cxx = "1.0"
|
||||||
|
hashbrown = "0.14"
|
||||||
|
hf-hub = { workspace = true }
|
||||||
log = { version = "0.4", features = [] }
|
log = { version = "0.4", features = [] }
|
||||||
text-generation-router = { path = "../../router" }
|
text-generation-router = { path = "../../router" }
|
||||||
tokenizers = { version = "0.19", features = ["hf-hub"] }
|
tokenizers = { workspace = true }
|
||||||
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||||
tokio-stream = "0.1.15"
|
tokio-stream = "0.1.15"
|
||||||
thiserror = "1.0.62"
|
thiserror = "1.0.63"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-opentelemetry = "0.24"
|
tracing-opentelemetry = "0.25"
|
||||||
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
||||||
parking_lot = "0.12"
|
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
cmake = "0.1"
|
cmake = "0.1"
|
||||||
|
|
|
@ -6,7 +6,7 @@ use std::path::{absolute, PathBuf};
|
||||||
|
|
||||||
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
|
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
|
||||||
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
|
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
|
||||||
const CUDA_REQUIRED_VERSION: &str = "12.5";
|
const CUDA_REQUIRED_VERSION: &str = "12.6";
|
||||||
const MPI_REQUIRED_VERSION: &str = "4.1";
|
const MPI_REQUIRED_VERSION: &str = "4.1";
|
||||||
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
|
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
|
||||||
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
|
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
|
||||||
|
@ -36,7 +36,7 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
|
||||||
// Build the backend implementation through CMake
|
// Build the backend implementation through CMake
|
||||||
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
|
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
|
||||||
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
|
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
|
||||||
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default
|
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("75-real;80-real;86-real;89-real;90-real");
|
||||||
|
|
||||||
let mut install_path = PathBuf::from(install_path);
|
let mut install_path = PathBuf::from(install_path);
|
||||||
if !install_path.is_absolute() {
|
if !install_path.is_absolute() {
|
||||||
|
@ -81,7 +81,12 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
|
||||||
(PathBuf::from(install_path), deps_folder)
|
(PathBuf::from(install_path), deps_folder)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_ffi_layer(deps_folder: &PathBuf) {
|
fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
|
||||||
|
let ndebug = match is_debug {
|
||||||
|
true => "1",
|
||||||
|
false => "0",
|
||||||
|
};
|
||||||
|
|
||||||
CFG.include_prefix = "backends/trtllm";
|
CFG.include_prefix = "backends/trtllm";
|
||||||
cxx_build::bridge("src/lib.rs")
|
cxx_build::bridge("src/lib.rs")
|
||||||
.static_flag(true)
|
.static_flag(true)
|
||||||
|
@ -93,9 +98,14 @@ fn build_ffi_layer(deps_folder: &PathBuf) {
|
||||||
.include("/usr/local/tensorrt/include")
|
.include("/usr/local/tensorrt/include")
|
||||||
.file("src/ffi.cpp")
|
.file("src/ffi.cpp")
|
||||||
.std("c++20")
|
.std("c++20")
|
||||||
|
.define("NDEBUG", ndebug)
|
||||||
.compile("tgi_trtllm_backend");
|
.compile("tgi_trtllm_backend");
|
||||||
|
|
||||||
println!("cargo:rerun-if-changed=CMakeLists.txt");
|
println!("cargo:rerun-if-changed=CMakeLists.txt");
|
||||||
|
println!("cargo:rerun-if-changed=cmake/trtllm.cmake");
|
||||||
|
println!("cargo:rerun-if-changed=cmake/json.cmake");
|
||||||
|
println!("cargo:rerun-if-changed=cmake/fmt.cmake");
|
||||||
|
println!("cargo:rerun-if-changed=cmake/spdlog.cmake");
|
||||||
println!("cargo:rerun-if-changed=include/backend.h");
|
println!("cargo:rerun-if-changed=include/backend.h");
|
||||||
println!("cargo:rerun-if-changed=lib/backend.cpp");
|
println!("cargo:rerun-if-changed=lib/backend.cpp");
|
||||||
println!("cargo:rerun-if-changed=include/ffi.h");
|
println!("cargo:rerun-if-changed=include/ffi.h");
|
||||||
|
@ -115,7 +125,7 @@ fn main() {
|
||||||
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
|
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
|
||||||
|
|
||||||
// Build the FFI layer calling the backend above
|
// Build the FFI layer calling the backend above
|
||||||
build_ffi_layer(&deps_folder);
|
build_ffi_layer(&deps_folder, is_debug);
|
||||||
|
|
||||||
// Emit linkage search path
|
// Emit linkage search path
|
||||||
probe!("ompi", MPI_REQUIRED_VERSION);
|
probe!("ompi", MPI_REQUIRED_VERSION);
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
fmt
|
fmt
|
||||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
DOWNLOAD_EXTRACT_TIMESTAMP
|
||||||
GIT_TAG 11.0.1
|
URL https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.tar.gz
|
||||||
)
|
)
|
||||||
FetchContent_MakeAvailable(fmt)
|
FetchContent_MakeAvailable(fmt)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
fetchcontent_declare(
|
fetchcontent_declare(
|
||||||
json
|
json
|
||||||
|
DOWNLOAD_EXTRACT_TIMESTAMP
|
||||||
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
|
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
|
||||||
)
|
)
|
||||||
fetchcontent_makeavailable(json)
|
fetchcontent_makeavailable(json)
|
||||||
|
|
|
@ -11,7 +11,7 @@ endif ()
|
||||||
|
|
||||||
fetchcontent_declare(
|
fetchcontent_declare(
|
||||||
spdlog
|
spdlog
|
||||||
GIT_REPOSITORY https://github.com/gabime/spdlog.git
|
DOWNLOAD_EXTRACT_TIMESTAMP
|
||||||
GIT_TAG v1.14.1
|
URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
|
||||||
)
|
)
|
||||||
fetchcontent_makeavailable(spdlog)
|
fetchcontent_makeavailable(spdlog)
|
||||||
|
|
|
@ -23,8 +23,9 @@ endif ()
|
||||||
fetchcontent_declare(
|
fetchcontent_declare(
|
||||||
trtllm
|
trtllm
|
||||||
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
|
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
|
||||||
GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1
|
GIT_TAG 201135e58aa525af7e523d091d4c9584229524bc
|
||||||
GIT_SHALLOW FALSE
|
GIT_SHALLOW FALSE
|
||||||
|
DOWNLOAD_EXTRACT_TIMESTAMP
|
||||||
)
|
)
|
||||||
fetchcontent_makeavailable(trtllm)
|
fetchcontent_makeavailable(trtllm)
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,12 @@ namespace huggingface::tgi::backends {
|
||||||
using RequestId = tle::IdType;
|
using RequestId = tle::IdType;
|
||||||
using TokenId = tle::TokenIdType;
|
using TokenId = tle::TokenIdType;
|
||||||
|
|
||||||
|
const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
|
||||||
|
constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
|
||||||
|
"Submitting inference [{}] to the executor ({:d} already in-flight)");
|
||||||
|
constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
|
||||||
|
"Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize all the components required by TRTLLM.
|
* Initialize all the components required by TRTLLM.
|
||||||
* It is required to call this function before attempting to load any engine
|
* It is required to call this function before attempting to load any engine
|
||||||
|
@ -54,7 +60,7 @@ namespace huggingface::tgi::backends {
|
||||||
float_t repetition_penalty,
|
float_t repetition_penalty,
|
||||||
float_t frequency_penalty,
|
float_t frequency_penalty,
|
||||||
uint64_t seed
|
uint64_t seed
|
||||||
);
|
) noexcept;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -64,18 +70,15 @@ namespace huggingface::tgi::backends {
|
||||||
const json config;
|
const json config;
|
||||||
tle::Executor executor;
|
tle::Executor executor;
|
||||||
|
|
||||||
|
/** Frequently accessed variables cached here **/
|
||||||
|
uint32_t maxNumTokens;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit TensorRtLlmBackend(
|
explicit TensorRtLlmBackend(
|
||||||
const std::filesystem::path &engineFolder,
|
const std::filesystem::path &engineFolder,
|
||||||
const std::filesystem::path &executorWorker
|
const std::filesystem::path &executorWorker
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
|
||||||
* Indicate if the backend is ready to accept incoming request
|
|
||||||
* @return true if ready, false otherwise
|
|
||||||
*/
|
|
||||||
[[nodiscard]] bool IsReady() const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Query the executor for the number of token available for pulling
|
* Query the executor for the number of token available for pulling
|
||||||
* @return
|
* @return
|
||||||
|
@ -95,25 +98,16 @@ namespace huggingface::tgi::backends {
|
||||||
*/
|
*/
|
||||||
[[nodiscard]] RequestId Submit(
|
[[nodiscard]] RequestId Submit(
|
||||||
const std::vector<TokenId> &tokens,
|
const std::vector<TokenId> &tokens,
|
||||||
int32_t topK,
|
const uint32_t maxNewTokens,
|
||||||
float_t topP,
|
const int32_t topK,
|
||||||
float_t temperature,
|
const float_t topP,
|
||||||
float_t repetition_penalty,
|
const float_t temperature,
|
||||||
float_t frequency_penalty,
|
const float_t repetition_penalty,
|
||||||
uint64_t seed
|
const float_t frequency_penalty,
|
||||||
|
const uint64_t seed
|
||||||
);
|
);
|
||||||
|
|
||||||
/**
|
[[nodiscard]] std::vector<tle::Response> PullNewTokens();
|
||||||
*
|
|
||||||
* @param requestId The request id to poll the generation results
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
std::vector<tle::Response> Poll(RequestId requestId);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Stop the underlying executor
|
|
||||||
*/
|
|
||||||
void Shutdown();
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,20 +5,31 @@
|
||||||
#ifndef TGI_TRTLLM_BACKEND_FFI_H
|
#ifndef TGI_TRTLLM_BACKEND_FFI_H
|
||||||
#define TGI_TRTLLM_BACKEND_FFI_H
|
#define TGI_TRTLLM_BACKEND_FFI_H
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <memory>
|
||||||
#include "backend.h"
|
#include "backend.h"
|
||||||
|
|
||||||
namespace huggingface::tgi::backends {
|
namespace huggingface::tgi::backends {
|
||||||
class TensorRtLlmBackendImpl;
|
class TensorRtLlmBackendImpl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Template to support returning error from TllmException back to Rust in a Result<>
|
||||||
|
#include <tensorrt_llm/common/tllmException.h>
|
||||||
|
|
||||||
|
namespace rust::behavior {
|
||||||
|
template<typename Try, typename Fail>
|
||||||
|
static void trycatch(Try &&func, Fail &&fail) noexcept try {
|
||||||
|
func();
|
||||||
|
} catch (tensorrt_llm::common::TllmException &e) {
|
||||||
|
fail(e.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#include "backends/trtllm/src/lib.rs.h"
|
#include "backends/trtllm/src/lib.rs.h"
|
||||||
|
|
||||||
|
|
||||||
namespace huggingface::tgi::backends {
|
namespace huggingface::tgi::backends {
|
||||||
|
|
||||||
// struct GenerationContext;
|
|
||||||
|
|
||||||
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
|
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
|
||||||
public:
|
public:
|
||||||
/***
|
/***
|
||||||
|
@ -28,15 +39,10 @@ namespace huggingface::tgi::backends {
|
||||||
*/
|
*/
|
||||||
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
|
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
|
||||||
|
|
||||||
/***
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
bool IsReady() const;
|
|
||||||
|
|
||||||
/***
|
/***
|
||||||
*
|
*
|
||||||
* @param tokens
|
* @param tokens
|
||||||
|
* @param maxNewTokens
|
||||||
* @param topK
|
* @param topK
|
||||||
* @param topP
|
* @param topP
|
||||||
* @param temperature
|
* @param temperature
|
||||||
|
@ -47,21 +53,15 @@ namespace huggingface::tgi::backends {
|
||||||
*/
|
*/
|
||||||
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
|
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
|
||||||
uint64_t
|
uint64_t
|
||||||
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
|
Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
|
||||||
|
int32_t topK, float_t topP, float_t temperature,
|
||||||
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
|
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
|
||||||
|
|
||||||
/***
|
/***
|
||||||
*
|
*
|
||||||
* @param requestId
|
|
||||||
* @param ctx
|
|
||||||
* @param callback
|
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
size_t StreamTokens(
|
std::unique_ptr<std::vector<GenerationStep>> PullTokens();
|
||||||
const RequestId requestId,
|
|
||||||
huggingface::tgi::backends::GenerationContext *ctx,
|
|
||||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
|
||||||
huggingface::tgi::backends::GenerationStep)> callback);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/***
|
/***
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
namespace huggingface::hardware::cuda {
|
namespace huggingface::hardware::cuda {
|
||||||
|
|
||||||
#define AMPERE_SM_MAJOR 8
|
#define AMPERE_SM_MAJOR 8
|
||||||
#define HOPPER_SM_MAJOR 8
|
#define HOPPER_SM_MAJOR 9
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Store information about the version of the CUDA Compute Capabilities detected on the device
|
* Store information about the version of the CUDA Compute Capabilities detected on the device
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
#include <cstdlib>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
#include <fmt/ranges.h>
|
#include <fmt/ranges.h>
|
||||||
|
@ -8,10 +9,23 @@
|
||||||
#include "hardware.h"
|
#include "hardware.h"
|
||||||
|
|
||||||
void huggingface::tgi::backends::InitializeBackend() {
|
void huggingface::tgi::backends::InitializeBackend() {
|
||||||
|
if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
|
||||||
|
std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
|
||||||
|
std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
|
||||||
|
return std::tolower(c);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (log_level == "debug")
|
||||||
|
spdlog::set_level(spdlog::level::debug);
|
||||||
|
else
|
||||||
|
spdlog::set_level(spdlog::level::info);
|
||||||
|
}
|
||||||
|
|
||||||
SPDLOG_INFO("Initializing Backend...");
|
SPDLOG_INFO("Initializing Backend...");
|
||||||
nvmlInit_v2();
|
nvmlInit_v2();
|
||||||
initTrtLlmPlugins();
|
initTrtLlmPlugins();
|
||||||
|
|
||||||
|
SPDLOG_INFO("Backend Executor Version: {}", tle::version());
|
||||||
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
|
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
|
||||||
if (numGpus.has_value()) {
|
if (numGpus.has_value()) {
|
||||||
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
|
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
|
||||||
|
@ -22,7 +36,7 @@ void huggingface::tgi::backends::InitializeBackend() {
|
||||||
|
|
||||||
[[nodiscard]]
|
[[nodiscard]]
|
||||||
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
|
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
|
||||||
tle::ExecutorConfig execConfig(1);
|
tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1);
|
||||||
|
|
||||||
// Retrieve the compute capabilities to enable some options at runtime
|
// Retrieve the compute capabilities to enable some options at runtime
|
||||||
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
|
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
|
||||||
|
@ -55,12 +69,13 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co
|
||||||
}
|
}
|
||||||
|
|
||||||
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
||||||
uint32_t topK,
|
const uint32_t topK,
|
||||||
float_t topP,
|
const float_t topP,
|
||||||
float_t temperature,
|
const float_t temperature,
|
||||||
float_t repetition_penalty,
|
const float_t repetition_penalty,
|
||||||
float_t frequency_penalty,
|
const float_t frequency_penalty,
|
||||||
uint64_t seed) {
|
const uint64_t seed) noexcept {
|
||||||
|
|
||||||
return tle::SamplingConfig(
|
return tle::SamplingConfig(
|
||||||
1, // TGI only use a single beam
|
1, // TGI only use a single beam
|
||||||
topK,
|
topK,
|
||||||
|
@ -83,26 +98,29 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
||||||
const std::filesystem::path &executorWorker
|
const std::filesystem::path &executorWorker
|
||||||
) :
|
) :
|
||||||
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
|
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
|
||||||
executor(
|
executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
|
||||||
enginesFolder,
|
GetExecutorConfig(config, executorWorker.string())) {
|
||||||
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
|
|
||||||
GetExecutorConfig(config, executorWorker.string()
|
|
||||||
)) {
|
|
||||||
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
|
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
|
||||||
}
|
|
||||||
|
|
||||||
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
|
// Cache variables
|
||||||
return executor.canEnqueueRequests();
|
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard("Returned number of requests needs to be consumed")]]
|
[[nodiscard("Returned number of requests needs to be consumed")]]
|
||||||
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
|
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
|
||||||
return executor.getNumResponsesReady();
|
const auto numResponses = executor.getNumResponsesReady();
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
if(numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return numResponses;
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
||||||
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
||||||
const std::vector<tle::TokenIdType> &tokens,
|
const std::vector<tle::TokenIdType> &tokens,
|
||||||
|
const uint32_t maxNewTokens,
|
||||||
const int32_t topK,
|
const int32_t topK,
|
||||||
const float_t topP,
|
const float_t topP,
|
||||||
const float_t temperature,
|
const float_t temperature,
|
||||||
|
@ -110,37 +128,23 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
||||||
const float_t frequency_penalty,
|
const float_t frequency_penalty,
|
||||||
const uint64_t seed
|
const uint64_t seed
|
||||||
) {
|
) {
|
||||||
#ifdef NDEBUG
|
const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
|
||||||
SPDLOG_DEBUG(
|
#ifndef NDEBUG
|
||||||
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
|
{
|
||||||
tokens.size(),
|
const auto &iterations = executor.getLatestIterationStats();
|
||||||
executor.getLatestIterationStats().back().numActiveRequests
|
const auto &lastIteration = iterations.front();
|
||||||
);
|
|
||||||
#else
|
SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
|
||||||
SPDLOG_DEBUG(
|
SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||||
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
|
SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
|
||||||
fmt::join(tokens, ", "),
|
}
|
||||||
executor.getLatestIterationStats().front().numActiveRequests
|
|
||||||
);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
|
|
||||||
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
|
|
||||||
|
|
||||||
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||||
const auto output = tle::OutputConfig(true, false, false, true, false);
|
const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
|
||||||
return executor.enqueueRequest(
|
return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
|
||||||
tle::Request{tokens, maxNewTokens, true, sampling, output});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard("Generated tokens result must be used")]]
|
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
|
||||||
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
|
return executor.awaitResponses();
|
||||||
SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
|
|
||||||
return executor.awaitResponses(requestId);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
|
|
||||||
SPDLOG_INFO("Shutting down executor");
|
|
||||||
executor.shutdown();
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,12 +2,13 @@
|
||||||
|
|
||||||
set -ex
|
set -ex
|
||||||
|
|
||||||
TRT_VER="10.2.0.19"
|
TRT_VER_BASE="10.4.0"
|
||||||
CUDA_VER="12.5"
|
TRT_VER_FULL="${TRT_VER_BASE}.26"
|
||||||
CUDNN_VER="9.2.1.18-1"
|
CUDA_VER="12.6"
|
||||||
NCCL_VER="2.22.3-1+cuda12.5"
|
CUDNN_VER="9.5.0.50-1"
|
||||||
CUBLAS_VER="12.5.3.2-1"
|
NCCL_VER="2.22.3-1+cuda12.6"
|
||||||
NVRTC_VER="12.5.82-1"
|
CUBLAS_VER="12.6.3.3-1"
|
||||||
|
NVRTC_VER="12.6.77-1"
|
||||||
|
|
||||||
for i in "$@"; do
|
for i in "$@"; do
|
||||||
case $i in
|
case $i in
|
||||||
|
@ -32,8 +33,9 @@ install_ubuntu_requirements() {
|
||||||
ARCH=$(uname -m)
|
ARCH=$(uname -m)
|
||||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||||
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
|
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
|
||||||
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
|
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb
|
||||||
dpkg -i cuda-keyring_1.0-1_all.deb
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
rm /etc/apt/sources.list.d/cuda-ubuntu2404-x86_64.list
|
||||||
|
|
||||||
apt-get update
|
apt-get update
|
||||||
if [[ $(apt list --installed | grep libcudnn9) ]]; then
|
if [[ $(apt list --installed | grep libcudnn9) ]]; then
|
||||||
|
@ -71,7 +73,7 @@ install_centos_requirements() {
|
||||||
install_tensorrt() {
|
install_tensorrt() {
|
||||||
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
|
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
|
||||||
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
|
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
|
||||||
TRT_CUDA_VERSION="12.5"
|
TRT_CUDA_VERSION="12.6"
|
||||||
|
|
||||||
if [ -z "$RELEASE_URL_TRT" ];then
|
if [ -z "$RELEASE_URL_TRT" ];then
|
||||||
ARCH=${TRT_TARGETARCH}
|
ARCH=${TRT_TARGETARCH}
|
||||||
|
@ -79,12 +81,12 @@ install_tensorrt() {
|
||||||
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
|
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
|
||||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||||
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
|
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
|
||||||
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
|
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-24.04" && OS="ubuntu-24.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
|
||||||
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
|
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_BASE}/tars/TensorRT-${TRT_VER_FULL}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
|
||||||
fi
|
fi
|
||||||
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
|
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
|
||||||
tar -xf /tmp/TensorRT.tar -C /usr/local/
|
tar -xf /tmp/TensorRT.tar -C /usr/local/
|
||||||
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
|
mv /usr/local/TensorRT-${TRT_VER_FULL} /usr/local/tensorrt
|
||||||
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
|
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
|
||||||
rm -rf /tmp/TensorRT.tar
|
rm -rf /tmp/TensorRT.tar
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,330 +0,0 @@
|
||||||
use std::future::Future;
|
|
||||||
use std::path::Path;
|
|
||||||
use std::pin::{pin, Pin};
|
|
||||||
use std::str::FromStr;
|
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
use std::sync::{Arc, OnceLock};
|
|
||||||
use std::task::{Context, Poll};
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use cxx::UniquePtr;
|
|
||||||
use log::{error, warn};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
|
||||||
use tokio::time::{sleep, Instant};
|
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
|
||||||
use tokio_stream::{Stream, StreamExt};
|
|
||||||
use tracing::{instrument, span, Level};
|
|
||||||
|
|
||||||
// use tokio::sync::RwLock;
|
|
||||||
use parking_lot::RwLock;
|
|
||||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
|
||||||
use text_generation_router::validation::ValidationError::UnsupportedModality;
|
|
||||||
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
|
|
||||||
use text_generation_router::{FinishReason, Token};
|
|
||||||
|
|
||||||
use crate::errors::TensorRtLlmBackendError;
|
|
||||||
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
|
|
||||||
|
|
||||||
// Value used to poll the state of the generation stream
|
|
||||||
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
|
|
||||||
|
|
||||||
type InferResult<T> = Result<T, InferError>;
|
|
||||||
|
|
||||||
pub(crate) struct Generation {
|
|
||||||
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
|
||||||
done: Arc<AtomicBool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Holds the user provided input to be executed along with a channel allowing
|
|
||||||
/// to bubble up all the generated tokens for that tokens the to end stream.
|
|
||||||
pub struct GenerationContext {
|
|
||||||
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
|
||||||
tokenizer: Arc<Tokenizer>,
|
|
||||||
tokens: Vec<u32>,
|
|
||||||
done: Arc<AtomicBool>,
|
|
||||||
queued: Instant,
|
|
||||||
start: Option<Instant>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Stream for Generation {
|
|
||||||
type Item = usize;
|
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
||||||
let interval = POLLING_INTERVAL_US.get_or_init(|| {
|
|
||||||
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
|
|
||||||
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
|
|
||||||
});
|
|
||||||
|
|
||||||
if !self.done.load(Ordering::Relaxed) {
|
|
||||||
let backend = pin!(self.executor.read());
|
|
||||||
let status = match backend.poll(ctx) {
|
|
||||||
Poll::Ready(executor_r) => {
|
|
||||||
let ready = executor_r.num_responses_ready();
|
|
||||||
if ready == 0 {
|
|
||||||
Poll::Pending
|
|
||||||
} else {
|
|
||||||
Poll::Ready(Some(ready))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Poll::Pending => Poll::Pending,
|
|
||||||
};
|
|
||||||
|
|
||||||
let waker = ctx.waker().clone();
|
|
||||||
tokio::spawn(async {
|
|
||||||
sleep(Duration::from_micros(*interval)).await;
|
|
||||||
waker.wake();
|
|
||||||
});
|
|
||||||
|
|
||||||
status
|
|
||||||
} else {
|
|
||||||
Poll::Ready(None) // end of stream
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
|
||||||
(1, None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe impl Send for TensorRtLlmBackendImpl {}
|
|
||||||
unsafe impl Sync for TensorRtLlmBackendImpl {}
|
|
||||||
|
|
||||||
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
|
|
||||||
pub struct TensorRtLlmBackend {
|
|
||||||
tokenizer: Arc<Tokenizer>,
|
|
||||||
|
|
||||||
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
|
|
||||||
// the number of available tokens (read only) in the Generation stream
|
|
||||||
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TensorRtLlmBackend {
|
|
||||||
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
engine_folder: P,
|
|
||||||
executor_worker_path: PP,
|
|
||||||
) -> Result<Self, TensorRtLlmBackendError> {
|
|
||||||
Ok(TensorRtLlmBackend {
|
|
||||||
tokenizer: Arc::new(tokenizer),
|
|
||||||
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
|
|
||||||
engine_folder.as_ref().to_str().unwrap(),
|
|
||||||
executor_worker_path.as_ref().to_str().unwrap(),
|
|
||||||
))),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
|
|
||||||
if request.top_n_tokens > 1 {
|
|
||||||
return Err(InferError::ValidationError(
|
|
||||||
ValidationError::TopNTokensDisabled,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Is it really needed? How can it be validated before?
|
|
||||||
if request.parameters.grammar.is_some() {
|
|
||||||
return Err(InferError::ValidationError(ValidationError::Grammar));
|
|
||||||
}
|
|
||||||
|
|
||||||
match request.inputs.len() {
|
|
||||||
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
|
|
||||||
2.. => Err(InferError::GenerationError(
|
|
||||||
"TensorRT-LLM backend don't support multi-chunk".into(),
|
|
||||||
)),
|
|
||||||
1 => match request.inputs.first().expect("Single item-chunk") {
|
|
||||||
Chunk::Text(text) => Ok(text),
|
|
||||||
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn generate(
|
|
||||||
&self,
|
|
||||||
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
|
||||||
tokens: Vec<u32>,
|
|
||||||
top_k: u32,
|
|
||||||
top_p: f32,
|
|
||||||
temperature: f32,
|
|
||||||
repetition_penalty: f32,
|
|
||||||
frequency_penalty: f32,
|
|
||||||
seed: u64,
|
|
||||||
) {
|
|
||||||
let tokenizer = Arc::clone(&self.tokenizer);
|
|
||||||
let executor = Arc::clone(&self.backend);
|
|
||||||
|
|
||||||
// Let's push this in async context
|
|
||||||
tokio::spawn(async move {
|
|
||||||
// Define the generation state
|
|
||||||
let mut generation = Generation {
|
|
||||||
executor: executor.clone(),
|
|
||||||
done: Arc::new(AtomicBool::new(false)),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Define the context over the generation
|
|
||||||
// TODO(asap): Do we really need so many shared-ownership?
|
|
||||||
let ctx = Box::new(GenerationContext {
|
|
||||||
sender: sender.clone(),
|
|
||||||
tokenizer,
|
|
||||||
tokens: vec![],
|
|
||||||
done: Arc::clone(&generation.done),
|
|
||||||
start: None,
|
|
||||||
queued: Instant::now(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// We are leaking the context on-purpose to avoid the box being dropped while there are
|
|
||||||
// still computation ongoing
|
|
||||||
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
|
|
||||||
let ctx_ = Box::leak(ctx);
|
|
||||||
|
|
||||||
// Submit the request to the batcher
|
|
||||||
let request_id = span!(Level::DEBUG, "submit")
|
|
||||||
.in_scope(|| async {
|
|
||||||
let mut handle = executor.write().await;
|
|
||||||
let request_id = handle.pin_mut().submit(
|
|
||||||
&tokens,
|
|
||||||
top_k as i32,
|
|
||||||
top_p,
|
|
||||||
temperature,
|
|
||||||
repetition_penalty,
|
|
||||||
frequency_penalty,
|
|
||||||
seed,
|
|
||||||
);
|
|
||||||
|
|
||||||
request_id
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
|
|
||||||
while let Some(_) = generation.next().await {
|
|
||||||
let mut executor_w = executor.write().await;
|
|
||||||
let executor = executor_w.pin_mut();
|
|
||||||
|
|
||||||
span!(Level::DEBUG, "decode")
|
|
||||||
.in_scope(|| async {
|
|
||||||
unsafe {
|
|
||||||
executor.stream_tokens(
|
|
||||||
request_id,
|
|
||||||
ctx_,
|
|
||||||
|ctx: *mut GenerationContext, step: GenerationStep| {
|
|
||||||
let inner_ctx = &mut *ctx;
|
|
||||||
|
|
||||||
// Update the timestamp at which the request started effectively
|
|
||||||
// Can be a bit off, would need to be before the callback, let's see
|
|
||||||
inner_ctx.start.get_or_insert(Instant::now());
|
|
||||||
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
|
|
||||||
|
|
||||||
// Ensure we are not running into errors
|
|
||||||
let parcel = if !step.has_error {
|
|
||||||
// Insert the latest generated token to the tracker
|
|
||||||
inner_ctx.tokens.push(step.token_id);
|
|
||||||
|
|
||||||
// Decode the token
|
|
||||||
let text = inner_ctx
|
|
||||||
.tokenizer
|
|
||||||
.decode(&[step.token_id], true)
|
|
||||||
.expect("Failed to decode token");
|
|
||||||
|
|
||||||
let special = inner_ctx
|
|
||||||
.tokenizer
|
|
||||||
.get_added_vocabulary()
|
|
||||||
.is_special_token(&text);
|
|
||||||
|
|
||||||
// Create the structure holding the token
|
|
||||||
let token = Token {
|
|
||||||
id: step.token_id,
|
|
||||||
text,
|
|
||||||
logprob: step.log_prob,
|
|
||||||
special,
|
|
||||||
};
|
|
||||||
|
|
||||||
if step.is_final {
|
|
||||||
let generated_text = inner_ctx
|
|
||||||
.tokenizer
|
|
||||||
.decode(&inner_ctx.tokens, true)
|
|
||||||
.expect("Failed to decode generated_tokens");
|
|
||||||
|
|
||||||
Ok(InferStreamResponse::End {
|
|
||||||
token,
|
|
||||||
top_tokens: vec![],
|
|
||||||
generated_text: GeneratedText {
|
|
||||||
text: generated_text,
|
|
||||||
generated_tokens: inner_ctx.tokens.len() as u32,
|
|
||||||
finish_reason: FinishReason::EndOfSequenceToken,
|
|
||||||
seed: None,
|
|
||||||
},
|
|
||||||
start: inner_ctx.start.unwrap_or(Instant::now()),
|
|
||||||
queued: inner_ctx.queued,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
Ok(InferStreamResponse::Intermediate {
|
|
||||||
token,
|
|
||||||
top_tokens: vec![],
|
|
||||||
})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
error!("Error caught while decoding: {}", &step.error_msg);
|
|
||||||
Err(InferError::GenerationError(step.error_msg))
|
|
||||||
};
|
|
||||||
|
|
||||||
// Send the parcel to the client
|
|
||||||
inner_ctx
|
|
||||||
.sender
|
|
||||||
.send(parcel)
|
|
||||||
.expect("Failed to sent msg through the channel");
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
// "Properly" free the shared context...
|
|
||||||
// TODO: clean that piece of sh** asap
|
|
||||||
unsafe {
|
|
||||||
let _ = Box::from_raw(ctx_);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Backend for TensorRtLlmBackend {
|
|
||||||
#[instrument(skip_all)]
|
|
||||||
fn schedule(
|
|
||||||
&self,
|
|
||||||
request: ValidGenerateRequest,
|
|
||||||
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
|
|
||||||
// Let's add a few more validation
|
|
||||||
let input = TensorRtLlmBackend::validate(&request)?;
|
|
||||||
|
|
||||||
// Channel to stream the generated token as they come from the worker thread back to the transport layer
|
|
||||||
let (sender, receiver) = unbounded_channel();
|
|
||||||
|
|
||||||
// Unpack parameters
|
|
||||||
let params = &request.parameters;
|
|
||||||
|
|
||||||
// Preprocess the inputs to send to TRTLLM backend
|
|
||||||
let encoding = self
|
|
||||||
.tokenizer
|
|
||||||
.encode(input.as_str(), true)
|
|
||||||
.map_err(|e| InferError::GenerationError(e.to_string()))?;
|
|
||||||
|
|
||||||
// Generate the response
|
|
||||||
self.generate(
|
|
||||||
sender,
|
|
||||||
Vec::from(encoding.get_ids()),
|
|
||||||
params.top_k,
|
|
||||||
params.top_p,
|
|
||||||
params.temperature,
|
|
||||||
params.repetition_penalty,
|
|
||||||
params.frequency_penalty,
|
|
||||||
params.seed,
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(UnboundedReceiverStream::new(receiver))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn health(&self, _current_health: bool) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,9 +1,16 @@
|
||||||
|
use std::path::PathBuf;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use text_generation_router::server;
|
use text_generation_router::server;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum TensorRtLlmBackendError {
|
pub enum TensorRtLlmBackendError {
|
||||||
|
#[error("Provided engine folder {0} doesn't exist")]
|
||||||
|
EngineFolderDoesntExists(PathBuf),
|
||||||
|
#[error("Provided executorWorker binary path {0} doesn't exist")]
|
||||||
|
ExecutorWorkerNotFound(PathBuf),
|
||||||
|
#[error("TensorRT-LLM Runtime error: {0}")]
|
||||||
|
Runtime(String),
|
||||||
#[error("Tokenizer error: {0}")]
|
#[error("Tokenizer error: {0}")]
|
||||||
Tokenizer(String),
|
Tokenizer(String),
|
||||||
#[error("Argument validation error: {0}")]
|
#[error("Argument validation error: {0}")]
|
||||||
|
|
|
@ -3,11 +3,13 @@
|
||||||
//
|
//
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cmath>
|
#include <algorithm>
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
|
#include <functional>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
#include <ranges>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <spdlog/spdlog.h>
|
#include <spdlog/spdlog.h>
|
||||||
|
@ -20,61 +22,59 @@ huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
|
||||||
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
|
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
|
||||||
|
|
||||||
|
|
||||||
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
|
|
||||||
return TensorRtLlmBackend::IsReady();
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
|
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
|
||||||
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
|
rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
|
||||||
float_t frequency_penalty, uint64_t seed) {
|
int32_t topK, float_t topP, float_t temperature,
|
||||||
|
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed) {
|
||||||
|
|
||||||
// This will copy all the items from the initial slice
|
// This will copy all the items from the initial slice
|
||||||
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
|
std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
|
||||||
return TensorRtLlmBackend::Submit(
|
return TensorRtLlmBackend::Submit(
|
||||||
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
std::move(tokens_), maxNewTokens, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
|
std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
|
||||||
const uint64_t requestId,
|
huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
|
||||||
huggingface::tgi::backends::GenerationContext *ctx,
|
const auto responses = TensorRtLlmBackend::PullNewTokens();
|
||||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
|
||||||
huggingface::tgi::backends::GenerationStep)> callback) {
|
|
||||||
|
|
||||||
size_t numTokens = 0;
|
auto steps = std::make_unique<std::vector<GenerationStep>>();
|
||||||
for (const auto &item: Poll(requestId)) {
|
steps->reserve(responses.size());
|
||||||
GenerationStep step;
|
|
||||||
if (!item.hasError()) {
|
|
||||||
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
|
|
||||||
const auto decoded = item.getResult();
|
|
||||||
|
|
||||||
const auto token = decoded.outputTokenIds[0][0];
|
#ifndef NDEBUG
|
||||||
const auto isFinal = decoded.isFinal;
|
SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size());
|
||||||
const auto logProb = decoded.logProbs.value()[0][0];
|
#endif
|
||||||
|
|
||||||
++numTokens;
|
// Transform tle::Response to GenerationStep
|
||||||
|
std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
|
||||||
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
|
const auto reqId = r.getRequestId();
|
||||||
step = huggingface::tgi::backends::GenerationStep{
|
if (!r.hasError()) {
|
||||||
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
|
const auto result = r.getResult();
|
||||||
|
return GenerationStep{
|
||||||
|
reqId,
|
||||||
|
static_cast<uint32_t>(result.outputTokenIds[0][0]),
|
||||||
|
result.logProbs.value()[0][0],
|
||||||
|
result.isFinal,
|
||||||
|
false,
|
||||||
|
std::string()
|
||||||
};
|
};
|
||||||
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
|
|
||||||
} else {
|
} else {
|
||||||
// TODO : Return rest::Result with error
|
return GenerationStep{
|
||||||
const auto what = item.getErrorMsg();
|
reqId,
|
||||||
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
|
0,
|
||||||
step = huggingface::tgi::backends::GenerationStep{
|
0.0,
|
||||||
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
|
true,
|
||||||
|
true,
|
||||||
|
std::move(r.getErrorMsg())
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
callback(std::move(ctx), std::move(step));
|
return steps;
|
||||||
}
|
|
||||||
|
|
||||||
return numTokens;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
|
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
|
||||||
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
|
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
|
||||||
|
SPDLOG_INFO("Creating TensorRT-LLM Backend");
|
||||||
// Unconditionally call this to initialize and discover TRTLLM plugins
|
// Unconditionally call this to initialize and discover TRTLLM plugins
|
||||||
InitializeBackend();
|
InitializeBackend();
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
pub use backend::{GenerationContext, TensorRtLlmBackend};
|
pub use looper::TensorRtLlmBackendV2;
|
||||||
|
|
||||||
mod backend;
|
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
|
mod looper;
|
||||||
|
mod utils;
|
||||||
|
|
||||||
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
|
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
|
||||||
mod ffi {
|
mod ffi {
|
||||||
|
|
||||||
/// Struct used as shared type between rust and C++ to represent the result
|
/// Struct used as shared type between rust and C++ to represent the result
|
||||||
/// of a single decoding iteration
|
/// of a single decoding iteration
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct GenerationStep {
|
pub struct GenerationStep {
|
||||||
|
request_id: u64,
|
||||||
token_id: u32,
|
token_id: u32,
|
||||||
log_prob: f32,
|
log_prob: f32,
|
||||||
is_final: bool,
|
is_final: bool,
|
||||||
|
@ -16,10 +18,6 @@ mod ffi {
|
||||||
error_msg: String,
|
error_msg: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "Rust" {
|
|
||||||
type GenerationContext;
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe extern "C++" {
|
unsafe extern "C++" {
|
||||||
include!("backends/trtllm/src/ffi.cpp");
|
include!("backends/trtllm/src/ffi.cpp");
|
||||||
|
|
||||||
|
@ -44,10 +42,7 @@ mod ffi {
|
||||||
fn CreateTensorRtLlmBackend(
|
fn CreateTensorRtLlmBackend(
|
||||||
engine_folder: &str,
|
engine_folder: &str,
|
||||||
executor_worker: &str,
|
executor_worker: &str,
|
||||||
) -> UniquePtr<TensorRtLlmBackendImpl>;
|
) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
|
||||||
|
|
||||||
// #[rust_name = "is_ready"]
|
|
||||||
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
|
|
||||||
|
|
||||||
#[rust_name = "num_responses_ready"]
|
#[rust_name = "num_responses_ready"]
|
||||||
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
|
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
|
||||||
|
@ -56,23 +51,18 @@ mod ffi {
|
||||||
fn Submit(
|
fn Submit(
|
||||||
self: Pin<&mut TensorRtLlmBackendImpl>,
|
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||||
tokens: &[u32],
|
tokens: &[u32],
|
||||||
|
max_new_tokens: u32,
|
||||||
top_k: i32,
|
top_k: i32,
|
||||||
top_p: f32,
|
top_p: f32,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
repetition_penalty: f32,
|
repetition_penalty: f32,
|
||||||
frequency_penalty: f32,
|
frequency_penalty: f32,
|
||||||
seed: u64,
|
seed: u64,
|
||||||
) -> u64;
|
) -> Result<u64>;
|
||||||
|
|
||||||
#[rust_name = "stream_tokens"]
|
#[rust_name = "pull_tokens"]
|
||||||
unsafe fn StreamTokens(
|
fn PullTokens(
|
||||||
self: Pin<&mut TensorRtLlmBackendImpl>,
|
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||||
request_id: u64,
|
) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
|
||||||
ctx: *mut GenerationContext,
|
|
||||||
cb: unsafe fn(*mut GenerationContext, GenerationStep),
|
|
||||||
) -> usize;
|
|
||||||
|
|
||||||
// #[rust_name = "shutdown"]
|
|
||||||
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,395 @@
|
||||||
|
use std::hint;
|
||||||
|
use std::ops::Deref;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use cxx::UniquePtr;
|
||||||
|
use hashbrown::HashMap;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
|
||||||
|
use tokio::sync::TryAcquireError;
|
||||||
|
use tokio::task::{spawn_blocking, JoinHandle};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
|
use tracing::{debug, error, warn};
|
||||||
|
|
||||||
|
use text_generation_router::infer::InferError::{GenerationError, ValidationError};
|
||||||
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
|
use text_generation_router::validation::ValidationError::{
|
||||||
|
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
|
||||||
|
};
|
||||||
|
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
|
||||||
|
use text_generation_router::{FinishReason, Token};
|
||||||
|
|
||||||
|
use crate::errors::TensorRtLlmBackendError;
|
||||||
|
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
|
||||||
|
use crate::utils::first_line;
|
||||||
|
|
||||||
|
type InferResult<T> = Result<T, InferError>;
|
||||||
|
|
||||||
|
struct IdentifiableRequest<T> {
|
||||||
|
request_id: u64,
|
||||||
|
inner: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
|
||||||
|
struct GenerationContext {
|
||||||
|
request: ValidGenerateRequest,
|
||||||
|
start: Option<Instant>,
|
||||||
|
queued: Instant,
|
||||||
|
streamer: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
struct DecodedToken {
|
||||||
|
id: u32,
|
||||||
|
log_prob: f32,
|
||||||
|
is_final: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
|
||||||
|
type Error = InferError;
|
||||||
|
|
||||||
|
fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
|
||||||
|
if !step.has_error {
|
||||||
|
Ok(Self {
|
||||||
|
id: step.token_id,
|
||||||
|
log_prob: step.log_prob,
|
||||||
|
is_final: step.is_final,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err(GenerationError(step.error_msg.clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
|
||||||
|
struct DecodedTokenContext {
|
||||||
|
token: DecodedToken,
|
||||||
|
start: Option<Instant>,
|
||||||
|
queued: Instant,
|
||||||
|
channel: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn executor_status_looper(
|
||||||
|
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
|
||||||
|
max_inflight_requests: usize,
|
||||||
|
mut waiting_requests: UnboundedReceiver<GenerationContext>,
|
||||||
|
post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
|
||||||
|
) {
|
||||||
|
// Track the tuple (request_id, stream) for each request
|
||||||
|
let mut in_flights =
|
||||||
|
HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
|
||||||
|
|
||||||
|
// TODO: Does it need a spin-loop?
|
||||||
|
'scheduler: loop {
|
||||||
|
// Is there any request pending to be scheduled?
|
||||||
|
let awaiting_requests = waiting_requests.len();
|
||||||
|
for _ in 0..awaiting_requests {
|
||||||
|
// Retrieve all the requests
|
||||||
|
if let Some(mut ctx) = waiting_requests.blocking_recv() {
|
||||||
|
// Submit all the request to the executor and move the context to the in-flight tracker
|
||||||
|
let request = &ctx.request;
|
||||||
|
let generation_params = &request.parameters;
|
||||||
|
let stopping_params = &request.stopping_parameters;
|
||||||
|
let input_ids = request.input_ids.as_deref();
|
||||||
|
|
||||||
|
// Submit to the TensorRT-LLM executor for scheduling
|
||||||
|
match backend.pin_mut().submit(
|
||||||
|
&input_ids.unwrap(), // This is checked beforehand in validate()
|
||||||
|
stopping_params.max_new_tokens,
|
||||||
|
generation_params.top_k as i32,
|
||||||
|
generation_params.top_p,
|
||||||
|
generation_params.temperature,
|
||||||
|
generation_params.repetition_penalty,
|
||||||
|
generation_params.frequency_penalty,
|
||||||
|
generation_params.seed,
|
||||||
|
) {
|
||||||
|
Ok(request_id) => {
|
||||||
|
// Insert the context linked to the generated request id in the tracker
|
||||||
|
debug!("[in-flight] Added {}", request_id);
|
||||||
|
ctx.start = Some(Instant::now());
|
||||||
|
in_flights.insert(request_id, ctx);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Return to the caller
|
||||||
|
let what = e.to_string();
|
||||||
|
error!(error = what.as_str(), "Failed to schedule request");
|
||||||
|
|
||||||
|
let err = Err(InferError::Overloaded(TryAcquireError::NoPermits));
|
||||||
|
if let Err(_) = ctx.streamer.send(err) {
|
||||||
|
error!("Failed to send back error to the client");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if backend.num_responses_ready() > 0 {
|
||||||
|
match backend.pin_mut().pull_tokens() {
|
||||||
|
Ok(responses) => {
|
||||||
|
// Iterate through all the decoded token
|
||||||
|
for step in responses.deref() {
|
||||||
|
if let Some(ctx) = in_flights.get(&step.request_id) {
|
||||||
|
// Remove from tracked requests
|
||||||
|
let parcel =
|
||||||
|
DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
|
||||||
|
token: dt,
|
||||||
|
start: ctx.start,
|
||||||
|
queued: ctx.queued,
|
||||||
|
channel: ctx.streamer.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Submit the work to p:the post_processor
|
||||||
|
let posted = post_processor_sender.send((step.request_id, parcel));
|
||||||
|
|
||||||
|
if posted.is_err() || step.is_final {
|
||||||
|
debug!("Removing {}", step.request_id);
|
||||||
|
let _ = in_flights.remove(&step.request_id);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!("Untracked request {}", step.request_id,);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(ref err) => {
|
||||||
|
error!("Failed to get responses from the executor: {}.", err.what());
|
||||||
|
break 'scheduler;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hint the CPU we are spin-locking
|
||||||
|
hint::spin_loop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn post_processor_looper(
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
max_num_tokens: usize,
|
||||||
|
max_inflight_requests: usize,
|
||||||
|
mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
|
||||||
|
) {
|
||||||
|
let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
|
||||||
|
|
||||||
|
'post_processor: loop {
|
||||||
|
if decoded_tokens.is_closed() {
|
||||||
|
warn!("Post processor IPC is closed, loop will exit now.");
|
||||||
|
break 'post_processor;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
|
||||||
|
match decoded {
|
||||||
|
Ok(ctx) => {
|
||||||
|
states
|
||||||
|
.entry(request_id)
|
||||||
|
.and_modify(|s| s.push(*&ctx.token.id))
|
||||||
|
.or_insert_with(|| {
|
||||||
|
let mut state = Vec::with_capacity(max_num_tokens);
|
||||||
|
state.push(*&ctx.token.id);
|
||||||
|
state
|
||||||
|
});
|
||||||
|
|
||||||
|
let out = match tokenizer.decode(&[ctx.token.id], false) {
|
||||||
|
Ok(text) => {
|
||||||
|
let is_special =
|
||||||
|
tokenizer.get_added_vocabulary().is_special_token(&text);
|
||||||
|
let token = Token {
|
||||||
|
id: ctx.token.id,
|
||||||
|
text,
|
||||||
|
logprob: ctx.token.log_prob,
|
||||||
|
special: is_special,
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = if !ctx.token.is_final {
|
||||||
|
InferStreamResponse::Intermediate {
|
||||||
|
token,
|
||||||
|
top_tokens: vec![],
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let tokens = states.remove(&request_id).unwrap();
|
||||||
|
let text = tokenizer.decode(&tokens, true);
|
||||||
|
let generated_text = GeneratedText {
|
||||||
|
text: text.unwrap(),
|
||||||
|
generated_tokens: tokens.len() as u32,
|
||||||
|
finish_reason: FinishReason::EndOfSequenceToken,
|
||||||
|
seed: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
InferStreamResponse::End {
|
||||||
|
token,
|
||||||
|
top_tokens: vec![],
|
||||||
|
generated_text,
|
||||||
|
start: ctx.start.unwrap(),
|
||||||
|
queued: ctx.queued,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
Err(err) => Err(GenerationError(err.to_string())),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(_) = ctx.channel.send(out) {
|
||||||
|
warn!("Failed to send decoded token back to the user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_err) => {
|
||||||
|
todo!("what do we do?")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
|
||||||
|
engine_folder: P,
|
||||||
|
executor_worker_path: PP,
|
||||||
|
) -> Result<(String, String), TensorRtLlmBackendError> {
|
||||||
|
// Retrieve paths as &str for the backend creation
|
||||||
|
let engine_folder = engine_folder.as_ref();
|
||||||
|
let executor_worker_path = executor_worker_path.as_ref();
|
||||||
|
|
||||||
|
// Ensure the engine folder exists
|
||||||
|
if !engine_folder.exists() {
|
||||||
|
let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());
|
||||||
|
|
||||||
|
error!("Path validation failed: {}", err,);
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure executor worker binary exists
|
||||||
|
if !executor_worker_path.exists() {
|
||||||
|
let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf());
|
||||||
|
|
||||||
|
error!("Path validation failed: {}", err,);
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
let engine_folder = String::from(
|
||||||
|
engine_folder
|
||||||
|
.to_str()
|
||||||
|
.expect("Failed to convert engine_folder to valid UTF-8"),
|
||||||
|
);
|
||||||
|
|
||||||
|
let executor_worker_path = String::from(
|
||||||
|
executor_worker_path
|
||||||
|
.to_str()
|
||||||
|
.expect("Failed to convert executor_worker_path to valid UTF-8"),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok((engine_folder, executor_worker_path))
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl Send for TensorRtLlmBackendImpl {}
|
||||||
|
|
||||||
|
pub struct TensorRtLlmBackendV2 {
|
||||||
|
executor_looper: JoinHandle<()>,
|
||||||
|
post_processor_looper: JoinHandle<()>,
|
||||||
|
executor: UnboundedSender<GenerationContext>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TensorRtLlmBackendV2 {
|
||||||
|
pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
engine_folder: P,
|
||||||
|
executor_worker_path: PP,
|
||||||
|
max_inflight_requests: usize,
|
||||||
|
) -> Result<Self, TensorRtLlmBackendError> {
|
||||||
|
let (engine_folder, executor_worker_path) =
|
||||||
|
ensure_paths_exist(engine_folder, executor_worker_path)?;
|
||||||
|
|
||||||
|
// Allocate the IPC layer to communicate with the backend
|
||||||
|
let (executor_sender, executor_receiver) = unbounded_channel();
|
||||||
|
let (post_processor_sender, post_processor_receiver) = unbounded_channel();
|
||||||
|
|
||||||
|
// Create the FFI backend
|
||||||
|
let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
|
||||||
|
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
|
||||||
|
|
||||||
|
// Executor looper is responsible for scheduling and pulling requests state at regular interval
|
||||||
|
let executor_looper = spawn_blocking(move || {
|
||||||
|
executor_status_looper(
|
||||||
|
backend,
|
||||||
|
max_inflight_requests,
|
||||||
|
executor_receiver,
|
||||||
|
post_processor_sender,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
|
||||||
|
let post_processor_looper = spawn_blocking(move || {
|
||||||
|
post_processor_looper(
|
||||||
|
tokenizer,
|
||||||
|
512,
|
||||||
|
max_inflight_requests,
|
||||||
|
post_processor_receiver,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(TensorRtLlmBackendV2 {
|
||||||
|
executor_looper,
|
||||||
|
post_processor_looper,
|
||||||
|
executor: executor_sender,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
|
||||||
|
if request.input_ids.is_none() {
|
||||||
|
return Err(ValidationError(UnsupportedModality("No token provided")));
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.top_n_tokens > 1 {
|
||||||
|
return Err(ValidationError(TopNTokensDisabled));
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Is it really needed? How can it be validated before?
|
||||||
|
if request.parameters.grammar.is_some() {
|
||||||
|
return Err(ValidationError(Grammar));
|
||||||
|
}
|
||||||
|
|
||||||
|
match request.inputs.len() {
|
||||||
|
0 => Err(ValidationError(EmptyInput)),
|
||||||
|
2.. => Err(GenerationError(
|
||||||
|
"TensorRT-LLM backend don't support multi-chunk".into(),
|
||||||
|
)),
|
||||||
|
1 => match request.inputs.first().expect("Single item-chunk") {
|
||||||
|
Chunk::Text(_) => Ok(()),
|
||||||
|
Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Backend for TensorRtLlmBackendV2 {
|
||||||
|
fn schedule(
|
||||||
|
&self,
|
||||||
|
inner: ValidGenerateRequest,
|
||||||
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||||
|
Self::validate(&inner)?;
|
||||||
|
|
||||||
|
// Open-up the stream to send tokens
|
||||||
|
let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
|
||||||
|
|
||||||
|
// Send the context to the executor for scheduling
|
||||||
|
let queued = Instant::now();
|
||||||
|
match self.executor.send(GenerationContext {
|
||||||
|
request: inner,
|
||||||
|
start: None,
|
||||||
|
queued,
|
||||||
|
streamer,
|
||||||
|
}) {
|
||||||
|
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
|
||||||
|
Err(_) => Err(GenerationError(
|
||||||
|
"Failed to submit request to the backend".into(),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health(&self, current_health: bool) -> bool {
|
||||||
|
current_health
|
||||||
|
& !self.executor_looper.is_finished()
|
||||||
|
& !self.post_processor_looper.is_finished()
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,10 +1,16 @@
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use std::collections::HashMap;
|
use hf_hub::api::tokio::{Api, ApiBuilder};
|
||||||
use std::path::PathBuf;
|
use hf_hub::{Cache, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
|
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
|
||||||
use text_generation_backends_trtllm::TensorRtLlmBackend;
|
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
|
||||||
use text_generation_router::{server, usage_stats};
|
use text_generation_router::server::get_base_tokenizer;
|
||||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
use text_generation_router::usage_stats::UsageStatsLevel;
|
||||||
|
use text_generation_router::{server, HubTokenizerConfig};
|
||||||
|
|
||||||
/// App Configuration
|
/// App Configuration
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
|
@ -58,6 +64,130 @@ struct Args {
|
||||||
usage_stats: usage_stats::UsageStatsLevel,
|
usage_stats: usage_stats::UsageStatsLevel,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_tokenizer(
|
||||||
|
tokenizer_name: &str,
|
||||||
|
tokenizer_config_path: Option<&str>,
|
||||||
|
revision: Option<&str>,
|
||||||
|
) -> Option<Tokenizer> {
|
||||||
|
// Parse Huggingface hub token
|
||||||
|
let authorization_token = std::env::var("HF_TOKEN")
|
||||||
|
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
|
||||||
|
.ok();
|
||||||
|
|
||||||
|
// Tokenizer instance
|
||||||
|
let local_path = Path::new(tokenizer_name);
|
||||||
|
|
||||||
|
// Shared API builder initialization
|
||||||
|
let api_builder = || {
|
||||||
|
let mut builder = ApiBuilder::new()
|
||||||
|
.with_progress(false)
|
||||||
|
.with_token(authorization_token);
|
||||||
|
|
||||||
|
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
|
||||||
|
builder = builder.with_cache_dir(cache_dir.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
builder
|
||||||
|
};
|
||||||
|
|
||||||
|
// Decide if we need to use the API based on the revision and local path
|
||||||
|
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
||||||
|
|
||||||
|
// Initialize API if needed
|
||||||
|
#[derive(Clone)]
|
||||||
|
enum Type {
|
||||||
|
Api(Api),
|
||||||
|
Cache(Cache),
|
||||||
|
None,
|
||||||
|
}
|
||||||
|
let api = if use_api {
|
||||||
|
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
||||||
|
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
|
||||||
|
.map_err(|_| ())
|
||||||
|
.map(|cache_dir| Cache::new(cache_dir.into()))
|
||||||
|
.unwrap_or_else(|_| Cache::default());
|
||||||
|
tracing::warn!("Offline mode active using cache defaults");
|
||||||
|
Type::Cache(cache)
|
||||||
|
} else {
|
||||||
|
tracing::info!("Using the Hugging Face API");
|
||||||
|
match api_builder().build() {
|
||||||
|
Ok(api) => Type::Api(api),
|
||||||
|
Err(_) => {
|
||||||
|
tracing::warn!("Unable to build the Hugging Face API");
|
||||||
|
Type::None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Type::None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Load tokenizer and model info
|
||||||
|
let (
|
||||||
|
tokenizer_filename,
|
||||||
|
_config_filename,
|
||||||
|
tokenizer_config_filename,
|
||||||
|
_preprocessor_config_filename,
|
||||||
|
_processor_config_filename,
|
||||||
|
) = match api {
|
||||||
|
Type::None => (
|
||||||
|
Some(local_path.join("tokenizer.json")),
|
||||||
|
Some(local_path.join("config.json")),
|
||||||
|
Some(local_path.join("tokenizer_config.json")),
|
||||||
|
Some(local_path.join("preprocessor_config.json")),
|
||||||
|
Some(local_path.join("processor_config.json")),
|
||||||
|
),
|
||||||
|
Type::Api(api) => {
|
||||||
|
let api_repo = api.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.unwrap_or_else(|| "main").to_string(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
||||||
|
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
||||||
|
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||||
|
};
|
||||||
|
let config_filename = api_repo.get("config.json").await.ok();
|
||||||
|
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||||
|
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
|
||||||
|
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
||||||
|
|
||||||
|
(
|
||||||
|
tokenizer_filename,
|
||||||
|
config_filename,
|
||||||
|
tokenizer_config_filename,
|
||||||
|
preprocessor_config_filename,
|
||||||
|
processor_config_filename,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Type::Cache(cache) => {
|
||||||
|
let repo = cache.repo(Repo::with_revision(
|
||||||
|
tokenizer_name.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.clone().unwrap_or_else(|| "main").to_string(),
|
||||||
|
));
|
||||||
|
(
|
||||||
|
repo.get("tokenizer.json"),
|
||||||
|
repo.get("config.json"),
|
||||||
|
repo.get("tokenizer_config.json"),
|
||||||
|
repo.get("preprocessor_config.json"),
|
||||||
|
repo.get("processor_config.json"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||||
|
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||||
|
{
|
||||||
|
HubTokenizerConfig::from_file(filename)
|
||||||
|
} else {
|
||||||
|
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
||||||
|
};
|
||||||
|
|
||||||
|
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), TensorRtLlmBackendError> {
|
async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||||
// Get args
|
// Get args
|
||||||
|
@ -124,18 +254,26 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run server
|
// Create the backend
|
||||||
let tokenizer = Tokenizer::from_pretrained(
|
let tokenizer = get_tokenizer(
|
||||||
tokenizer_name.clone(),
|
&tokenizer_name,
|
||||||
Some(FromPretrainedParameters {
|
tokenizer_config_path.as_deref(),
|
||||||
revision: revision.clone().unwrap_or(String::from("main")),
|
revision.as_deref(),
|
||||||
user_agent: HashMap::new(),
|
|
||||||
auth_token,
|
|
||||||
}),
|
|
||||||
)
|
)
|
||||||
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
|
.await
|
||||||
|
.expect("Failed to retrieve tokenizer implementation");
|
||||||
|
|
||||||
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
|
info!("Successfully retrieved tokenizer {}", &tokenizer_name);
|
||||||
|
let backend = TensorRtLlmBackendV2::new(
|
||||||
|
tokenizer,
|
||||||
|
model_id,
|
||||||
|
executor_worker,
|
||||||
|
max_concurrent_requests,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
info!("Successfully created backend");
|
||||||
|
|
||||||
|
// Run server
|
||||||
server::run(
|
server::run(
|
||||||
backend,
|
backend,
|
||||||
max_concurrent_requests,
|
max_concurrent_requests,
|
||||||
|
@ -145,7 +283,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||||
max_input_tokens,
|
max_input_tokens,
|
||||||
max_total_tokens,
|
max_total_tokens,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
None,
|
auth_token,
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
tokenizer_config_path,
|
tokenizer_config_path,
|
||||||
revision,
|
revision,
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
///
|
||||||
|
/// Extract the first line of the provided string reference.
|
||||||
|
/// If there is no lines in the buffer, it returns a string
|
||||||
|
/// which content is defined by the content of `fail`
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `s`: The string buffer to extract the first-line from
|
||||||
|
/// * `fail`: A string content which is returned if no lines are
|
||||||
|
/// present in `s`
|
||||||
|
///
|
||||||
|
/// returns: String
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// let s = "My name is Morgan.\n I'm working at Hugging Face.";
|
||||||
|
/// first_line(s, "No line in string");
|
||||||
|
/// ```
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn first_line(s: &str, fail: &str) -> String {
|
||||||
|
s.lines().next().unwrap_or(fail).to_string()
|
||||||
|
}
|
Loading…
Reference in New Issue