[TENSORRT-LLM] - Implement new looper thread based backend (#2357)

* (backend) use parking_lot crate for RwLock fairness

# Conflicts:
#	backends/trtllm/src/backend.rs

* (launcher) default new server::run parameters to false for now

* (chore) fmt ... why?

* (ffi) use const for GetSamplingConfig

* (server) expose new SchedulingError

* (trt)

* (build) setup ccache if available

* (ffi) add max_new_tokens parameters

* (backend) cleanup a bit

* (backend) expose PullNewTokens

* (ffi) cleanup again

* (ffi) add missing headers imports

* (ffi) add template specialization to catch and convert to Rust Result<T, tensorrt_llm::common::TllmException>

* (looper) new looper initial implementation

* (ffi) remove narrowing type warning

* (ffi) encode the provided user prompt within each request thread

* (misc) change scope identifiers

* (backend) implement the post_processor background thread

* (misc) missing Result types for Rust

* use blocking_recv in the looper to consume at most the currently awaiting_requests before pulling new tokens, all in a single step

* (server) forward auth_token to server::run

* (build) fetchcontent use archives instead of git

* (ffi) fix usage of the wrong vector constructor, which made a capacity-fill call

* (ffi) missing namespace for tle::Response

* (ffi) do not use reference capture in lambda as we are not capturing anything

* (backend) refactor & cleanup

* (Dockerfile.trtllm) delete for now

* (misc) simplify [make_]move_iterator by using c++20 type inference

* (misc) no need to move for uint32_t items

* (scheduler) rework submit/pull logic

* (post) impl postprocessing

* (misc) delete backend.rs

* (misc) rerun-if-changed all the cmake modules

* (misc) move to latest trtllm

* (fix): HOPPER_SM_MAJOR is 9 not 8

* (misc): build for sm_{75,80,86,89,90} by default

* (misc): build with trtllm 0.13.0

* (misc): increase verbosity of spdlog

* (fix): do not recreate the stateful hashmap at every iteration

* (misc): update dependency in trtllm dockerfile

* (misc): update dependency in trtllm dockerfile

* (misc): disable logging in release mode

* (misc): improve trtllm download script robustness

* (fix): more fixes for Dockerfile

* misc(cuda): require 12.6

* chore(cmake): use correct policy for download_timestamp

* feat(looper): check engine and executorWorker paths exist before creating the backend

* chore(cmake): download timestamp should be before URL

* feat(looper): minor optimizations to avoid growing the containers too much

* chore(trtllm): move dockerfile to right place

* chore(trtllm): disable tokenizer parallelism by default

* chore(trtllm): fmt

* chore(trtllm): post-rebase commit

* chore(trtllm): remove unused method

* feat(trtllm): cache maxNumTokens to avoid calling JSON every time

* misc(router): remove SchedulingError

* feat(trtllm): do not tokenize twice

* Revert "chore(trtllm): remove unused method"

This reverts commit 31747163

* chore(rebase): fix invalid references

* chore(router): add python dependency

* Lint.

* Fix bad rebase

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Author: Funtowicz Morgan, 2024-10-25 07:17:14 +02:00 (committed by GitHub)
parent ed87b464b4
commit 43df056eee
22 changed files with 791 additions and 595 deletions
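
For orientation: the new backend is built around two blocking looper threads. The first drains waiting requests, submits them to the TensorRT-LLM executor and pulls ready responses; the second decodes the pulled tokens and streams them back to the client. The following is a minimal, self-contained sketch of that threading pattern only, not the backend itself: Request, Step and both looper functions are placeholder names, and std::thread with std::sync::mpsc stand in for the tokio::task::spawn_blocking tasks and unbounded channels used in backends/trtllm/src/looper.rs further down.

use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError};
use std::thread;

// Placeholder types standing in for ValidGenerateRequest / GenerationStep.
struct Request { id: u64 }
struct Step { request_id: u64, is_final: bool }

// Executor-side looper: drain every waiting request without blocking,
// "submit" it, then forward whatever the executor has ready.
fn executor_looper(waiting: Receiver<Request>, to_post: Sender<Step>) {
    let mut in_flight: HashMap<u64, Request> = HashMap::new();
    loop {
        loop {
            match waiting.try_recv() {
                Ok(req) => { in_flight.insert(req.id, req); } // Submit() happens here
                Err(TryRecvError::Empty) => break,
                Err(TryRecvError::Disconnected) if in_flight.is_empty() => return,
                Err(TryRecvError::Disconnected) => break,
            }
        }
        // PullNewTokens() would go here; this sketch fakes one final step per request.
        // The real looper only evicts requests whose step is final or whose channel closed.
        for (id, _request) in in_flight.drain() {
            if to_post.send(Step { request_id: id, is_final: true }).is_err() {
                return; // the post-processor side is gone
            }
        }
        std::hint::spin_loop();
    }
}

// Post-processor looper: decode each pulled step and stream it back to the client.
fn post_processor_looper(from_executor: Receiver<Step>) {
    while let Ok(step) = from_executor.recv() {
        println!("request {} -> token (final = {})", step.request_id, step.is_final);
    }
}

fn main() {
    let (req_tx, req_rx) = channel();
    let (step_tx, step_rx) = channel();
    let executor = thread::spawn(move || executor_looper(req_rx, step_tx));
    let post_processor = thread::spawn(move || post_processor_looper(step_rx));
    req_tx.send(Request { id: 1 }).unwrap();
    drop(req_tx); // the real router keeps this sender alive for its whole lifetime
    executor.join().unwrap();
    post_processor.join().unwrap();
}

The real implementation keeps the same shape, with the extra step of tracking in-flight requests by the executor-assigned request id and forwarding decode errors through the same channel.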

Cargo.lock generated

@ -2706,9 +2706,9 @@ dependencies = [
[[package]] [[package]]
name = "opentelemetry" name = "opentelemetry"
version = "0.23.0" version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b69a91d4893e713e06f724597ad630f1fa76057a5e1026c0ca67054a9032a76" checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-sink", "futures-sink",
@ -2819,19 +2819,17 @@ dependencies = [
[[package]] [[package]]
name = "opentelemetry_sdk" name = "opentelemetry_sdk"
version = "0.23.0" version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae312d58eaa90a82d2e627fd86e075cf5230b3f11794e2ed74199ebbe572d4fd" checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"futures-channel", "futures-channel",
"futures-executor", "futures-executor",
"futures-util", "futures-util",
"glob", "glob",
"lazy_static",
"once_cell", "once_cell",
"opentelemetry 0.23.0", "opentelemetry 0.24.0",
"ordered-float 4.3.0",
"percent-encoding", "percent-encoding",
"rand", "rand",
"thiserror", "thiserror",
@ -4185,16 +4183,17 @@ dependencies = [
"cmake", "cmake",
"cxx", "cxx",
"cxx-build", "cxx-build",
"hashbrown 0.14.5",
"hf-hub",
"log", "log",
"parking_lot",
"pkg-config", "pkg-config",
"text-generation-router", "text-generation-router",
"thiserror", "thiserror",
"tokenizers 0.19.1", "tokenizers",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tracing", "tracing",
"tracing-opentelemetry 0.24.0", "tracing-opentelemetry 0.25.0",
"tracing-subscriber", "tracing-subscriber",
] ]
@ -4212,7 +4211,7 @@ dependencies = [
"tabled", "tabled",
"text-generation-client", "text-generation-client",
"thiserror", "thiserror",
"tokenizers 0.20.0", "tokenizers",
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@ -4292,7 +4291,7 @@ dependencies = [
"serde_json", "serde_json",
"sysinfo", "sysinfo",
"thiserror", "thiserror",
"tokenizers 0.20.0", "tokenizers",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tower-http", "tower-http",
@ -4341,7 +4340,7 @@ dependencies = [
"slotmap", "slotmap",
"text-generation-router", "text-generation-router",
"thiserror", "thiserror",
"tokenizers 0.20.0", "tokenizers",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tonic 0.10.2", "tonic 0.10.2",
@ -4392,7 +4391,7 @@ dependencies = [
"slotmap", "slotmap",
"text-generation-router", "text-generation-router",
"thiserror", "thiserror",
"tokenizers 0.20.0", "tokenizers",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tonic 0.10.2", "tonic 0.10.2",
@ -4514,39 +4513,6 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
"getrandom",
"hf-hub",
"indicatif",
"itertools 0.12.1",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]] [[package]]
name = "tokenizers" name = "tokenizers"
version = "0.20.0" version = "0.20.0"
@ -4933,14 +4899,14 @@ dependencies = [
[[package]] [[package]]
name = "tracing-opentelemetry" name = "tracing-opentelemetry"
version = "0.24.0" version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4" checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
dependencies = [ dependencies = [
"js-sys", "js-sys",
"once_cell", "once_cell",
"opentelemetry 0.23.0", "opentelemetry 0.24.0",
"opentelemetry_sdk 0.23.0", "opentelemetry_sdk 0.24.1",
"smallvec", "smallvec",
"tracing", "tracing",
"tracing-core", "tracing-core",

Dockerfile.trtllm (deleted)

@ -1,23 +0,0 @@
# All the tooling for CUDA
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder
WORKDIR /usr/src/tgi/backends/trtllm
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget
COPY . /usr/src/tgi
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl
# All the tooling for Rust
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src
# Include CUDA related libraries and tools to the Rust based image
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
ENV PATH=/usr/local/cuda/bin:$PATH
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH
RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3

Dockerfile_trtllm

@ -10,7 +10,7 @@ COPY . .
RUN cargo chef prepare --recipe-path recipe.json RUN cargo chef prepare --recipe-path recipe.json
# CUDA dependent dependencies resolver stage # CUDA dependent dependencies resolver stage
FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder FROM nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04 AS cuda-builder
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \ --mount=type=cache,target=/var/lib/apt,sharing=locked \
@ -26,6 +26,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
ninja-build \ ninja-build \
pkg-config \ pkg-config \
python3 \ python3 \
python3-dev \
python3-setuptools \ python3-setuptools \
tar \ tar \
wget wget
@ -82,10 +83,15 @@ RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$
cd backends/trtllm && \ cd backends/trtllm && \
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu22.04 AS runtime
RUN apt update && apt install -y python3 && \
rm -rf /var/lib/{apt,dpkg,cache,log}/
WORKDIR /usr/local/tgi/bin WORKDIR /usr/local/tgi/bin
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH" ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt

backends/trtllm/CMakeLists.txt

@ -1,5 +1,17 @@
cmake_minimum_required(VERSION 3.20) cmake_minimum_required(VERSION 3.20)
if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
find_program(CCACHE_EXECUTABLE "ccache")
if (CCACHE_EXECUTABLE)
message(STATUS "Using ccache")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
endif ()
endif ()
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
endif ()
project(tgi-trtllm-backend VERSION 1.0.0) project(tgi-trtllm-backend VERSION 1.0.0)
set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD 20)
@ -14,7 +26,7 @@ set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include"
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located") set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features # We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml) find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
#### External dependencies #### #### External dependencies ####
include(cmake/fmt.cmake) include(cmake/fmt.cmake)

backends/trtllm/Cargo.toml

@ -10,16 +10,17 @@ async-trait = "0.1"
async-stream = "0.3" async-stream = "0.3"
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
cxx = "1.0" cxx = "1.0"
hashbrown = "0.14"
hf-hub = { workspace = true }
log = { version = "0.4", features = [] } log = { version = "0.4", features = [] }
text-generation-router = { path = "../../router" } text-generation-router = { path = "../../router" }
tokenizers = { version = "0.19", features = ["hf-hub"] } tokenizers = { workspace = true }
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.15" tokio-stream = "0.1.15"
thiserror = "1.0.62" thiserror = "1.0.63"
tracing = "0.1" tracing = "0.1"
tracing-opentelemetry = "0.24" tracing-opentelemetry = "0.25"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
parking_lot = "0.12"
[build-dependencies] [build-dependencies]
cmake = "0.1" cmake = "0.1"

backends/trtllm/build.rs

@ -6,7 +6,7 @@ use std::path::{absolute, PathBuf};
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"]; const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST"); const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
const CUDA_REQUIRED_VERSION: &str = "12.5"; const CUDA_REQUIRED_VERSION: &str = "12.6";
const MPI_REQUIRED_VERSION: &str = "4.1"; const MPI_REQUIRED_VERSION: &str = "4.1";
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX"); const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR"); const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
@ -36,7 +36,7 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
// Build the backend implementation through CMake // Build the backend implementation through CMake
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi"); let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt"); let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("75-real;80-real;86-real;89-real;90-real");
let mut install_path = PathBuf::from(install_path); let mut install_path = PathBuf::from(install_path);
if !install_path.is_absolute() { if !install_path.is_absolute() {
@ -81,7 +81,12 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
(PathBuf::from(install_path), deps_folder) (PathBuf::from(install_path), deps_folder)
} }
fn build_ffi_layer(deps_folder: &PathBuf) { fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
let ndebug = match is_debug {
true => "1",
false => "0",
};
CFG.include_prefix = "backends/trtllm"; CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs") cxx_build::bridge("src/lib.rs")
.static_flag(true) .static_flag(true)
@ -93,9 +98,14 @@ fn build_ffi_layer(deps_folder: &PathBuf) {
.include("/usr/local/tensorrt/include") .include("/usr/local/tensorrt/include")
.file("src/ffi.cpp") .file("src/ffi.cpp")
.std("c++20") .std("c++20")
.define("NDEBUG", ndebug)
.compile("tgi_trtllm_backend"); .compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt"); println!("cargo:rerun-if-changed=CMakeLists.txt");
println!("cargo:rerun-if-changed=cmake/trtllm.cmake");
println!("cargo:rerun-if-changed=cmake/json.cmake");
println!("cargo:rerun-if-changed=cmake/fmt.cmake");
println!("cargo:rerun-if-changed=cmake/spdlog.cmake");
println!("cargo:rerun-if-changed=include/backend.h"); println!("cargo:rerun-if-changed=include/backend.h");
println!("cargo:rerun-if-changed=lib/backend.cpp"); println!("cargo:rerun-if-changed=lib/backend.cpp");
println!("cargo:rerun-if-changed=include/ffi.h"); println!("cargo:rerun-if-changed=include/ffi.h");
@ -115,7 +125,7 @@ fn main() {
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir); let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
// Build the FFI layer calling the backend above // Build the FFI layer calling the backend above
build_ffi_layer(&deps_folder); build_ffi_layer(&deps_folder, is_debug);
// Emit linkage search path // Emit linkage search path
probe!("ompi", MPI_REQUIRED_VERSION); probe!("ompi", MPI_REQUIRED_VERSION);

backends/trtllm/cmake/fmt.cmake

@ -1,6 +1,6 @@
FetchContent_Declare( FetchContent_Declare(
fmt fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt DOWNLOAD_EXTRACT_TIMESTAMP
GIT_TAG 11.0.1 URL https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.tar.gz
) )
FetchContent_MakeAvailable(fmt) FetchContent_MakeAvailable(fmt)

backends/trtllm/cmake/json.cmake

@ -1,5 +1,6 @@
fetchcontent_declare( fetchcontent_declare(
json json
DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
) )
fetchcontent_makeavailable(json) fetchcontent_makeavailable(json)

backends/trtllm/cmake/spdlog.cmake

@ -11,7 +11,7 @@ endif ()
fetchcontent_declare( fetchcontent_declare(
spdlog spdlog
GIT_REPOSITORY https://github.com/gabime/spdlog.git DOWNLOAD_EXTRACT_TIMESTAMP
GIT_TAG v1.14.1 URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
) )
fetchcontent_makeavailable(spdlog) fetchcontent_makeavailable(spdlog)

backends/trtllm/cmake/trtllm.cmake

@ -23,8 +23,9 @@ endif ()
fetchcontent_declare( fetchcontent_declare(
trtllm trtllm
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1 GIT_TAG 201135e58aa525af7e523d091d4c9584229524bc
GIT_SHALLOW FALSE GIT_SHALLOW FALSE
DOWNLOAD_EXTRACT_TIMESTAMP
) )
fetchcontent_makeavailable(trtllm) fetchcontent_makeavailable(trtllm)

backends/trtllm/include/backend.h

@ -23,6 +23,12 @@ namespace huggingface::tgi::backends {
using RequestId = tle::IdType; using RequestId = tle::IdType;
using TokenId = tle::TokenIdType; using TokenId = tle::TokenIdType;
const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
"Submitting inference [{}] to the executor ({:d} already in-flight)");
constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
"Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
/** /**
* Initialize all the components required by TRTLLM. * Initialize all the components required by TRTLLM.
* It is required to call this function before attempting to load any engine * It is required to call this function before attempting to load any engine
@ -54,7 +60,7 @@ namespace huggingface::tgi::backends {
float_t repetition_penalty, float_t repetition_penalty,
float_t frequency_penalty, float_t frequency_penalty,
uint64_t seed uint64_t seed
); ) noexcept;
/** /**
* *
@ -64,18 +70,15 @@ namespace huggingface::tgi::backends {
const json config; const json config;
tle::Executor executor; tle::Executor executor;
/** Frequently accessed variables cached here **/
uint32_t maxNumTokens;
public: public:
explicit TensorRtLlmBackend( explicit TensorRtLlmBackend(
const std::filesystem::path &engineFolder, const std::filesystem::path &engineFolder,
const std::filesystem::path &executorWorker const std::filesystem::path &executorWorker
); );
/**
* Indicate if the backend is ready to accept incoming request
* @return true if ready, false otherwise
*/
[[nodiscard]] bool IsReady() const;
/** /**
* Query the executor for the number of token available for pulling * Query the executor for the number of token available for pulling
* @return * @return
@ -95,25 +98,16 @@ namespace huggingface::tgi::backends {
*/ */
[[nodiscard]] RequestId Submit( [[nodiscard]] RequestId Submit(
const std::vector<TokenId> &tokens, const std::vector<TokenId> &tokens,
int32_t topK, const uint32_t maxNewTokens,
float_t topP, const int32_t topK,
float_t temperature, const float_t topP,
float_t repetition_penalty, const float_t temperature,
float_t frequency_penalty, const float_t repetition_penalty,
uint64_t seed const float_t frequency_penalty,
const uint64_t seed
); );
/** [[nodiscard]] std::vector<tle::Response> PullNewTokens();
*
* @param requestId The request id to poll the generation results
* @return
*/
std::vector<tle::Response> Poll(RequestId requestId);
/**
* Stop the underlying executor
*/
void Shutdown();
}; };
} }

backends/trtllm/include/ffi.h

@ -5,20 +5,31 @@
#ifndef TGI_TRTLLM_BACKEND_FFI_H #ifndef TGI_TRTLLM_BACKEND_FFI_H
#define TGI_TRTLLM_BACKEND_FFI_H #define TGI_TRTLLM_BACKEND_FFI_H
#include <cmath>
#include <cstddef> #include <cstddef>
#include <memory>
#include "backend.h" #include "backend.h"
namespace huggingface::tgi::backends { namespace huggingface::tgi::backends {
class TensorRtLlmBackendImpl; class TensorRtLlmBackendImpl;
} }
// Template to support returning error from TllmException back to Rust in a Result<>
#include <tensorrt_llm/common/tllmException.h>
namespace rust::behavior {
template<typename Try, typename Fail>
static void trycatch(Try &&func, Fail &&fail) noexcept try {
func();
} catch (tensorrt_llm::common::TllmException &e) {
fail(e.what());
}
}
#include "backends/trtllm/src/lib.rs.h" #include "backends/trtllm/src/lib.rs.h"
namespace huggingface::tgi::backends { namespace huggingface::tgi::backends {
// struct GenerationContext;
class TensorRtLlmBackendImpl : public TensorRtLlmBackend { class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
public: public:
/*** /***
@ -28,15 +39,10 @@ namespace huggingface::tgi::backends {
*/ */
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker); TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
/***
*
* @return
*/
bool IsReady() const;
/*** /***
* *
* @param tokens * @param tokens
* @param maxNewTokens
* @param topK * @param topK
* @param topP * @param topP
* @param temperature * @param temperature
@ -47,21 +53,15 @@ namespace huggingface::tgi::backends {
*/ */
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]] [[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
uint64_t uint64_t
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
int32_t topK, float_t topP, float_t temperature,
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed); float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
/*** /***
* *
* @param requestId
* @param ctx
* @param callback
* @return * @return
*/ */
size_t StreamTokens( std::unique_ptr<std::vector<GenerationStep>> PullTokens();
const RequestId requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback);
}; };
/*** /***
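
The rust::behavior::trycatch specialization added in this header is what turns a tensorrt_llm::common::TllmException thrown on the C++ side into an Err on the Rust side of the bridge: every bridged function declared with a Result<...> return type in src/lib.rs then yields a cxx::Exception carrying the original what() message. A hedged sketch of how a caller inside the crate might consume that is shown below; the try_create helper is hypothetical, and the real mapping lives in src/looper.rs, which additionally keeps only the first line of the message.

// Sketch only: assumes the bindings declared in src/lib.rs and the
// error enum declared in src/errors.rs further down in this diff.
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::create_tensorrt_llm_backend;

fn try_create(engines: &str, worker: &str) -> Result<(), TensorRtLlmBackendError> {
    match create_tensorrt_llm_backend(engines, worker) {
        // A TllmException raised during executor creation was caught by the
        // trycatch specialization above and arrives here as a cxx::Exception.
        Err(e) => Err(TensorRtLlmBackendError::Runtime(e.what().to_string())),
        Ok(_backend) => Ok(()),
    }
}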

backends/trtllm/include/hardware.h

@ -14,7 +14,7 @@
namespace huggingface::hardware::cuda { namespace huggingface::hardware::cuda {
#define AMPERE_SM_MAJOR 8 #define AMPERE_SM_MAJOR 8
#define HOPPER_SM_MAJOR 8 #define HOPPER_SM_MAJOR 9
/** /**
* Store information about the version of the CUDA Compute Capabilities detected on the device * Store information about the version of the CUDA Compute Capabilities detected on the device

backends/trtllm/lib/backend.cpp

@ -1,3 +1,4 @@
#include <cstdlib>
#include <fstream> #include <fstream>
#include <fmt/ranges.h> #include <fmt/ranges.h>
@ -8,10 +9,23 @@
#include "hardware.h" #include "hardware.h"
void huggingface::tgi::backends::InitializeBackend() { void huggingface::tgi::backends::InitializeBackend() {
if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
return std::tolower(c);
});
if (log_level == "debug")
spdlog::set_level(spdlog::level::debug);
else
spdlog::set_level(spdlog::level::info);
}
SPDLOG_INFO("Initializing Backend..."); SPDLOG_INFO("Initializing Backend...");
nvmlInit_v2(); nvmlInit_v2();
initTrtLlmPlugins(); initTrtLlmPlugins();
SPDLOG_INFO("Backend Executor Version: {}", tle::version());
const auto numGpus = huggingface::hardware::cuda::GetNumDevices(); const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
if (numGpus.has_value()) { if (numGpus.has_value()) {
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value()); SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
@ -22,7 +36,7 @@ void huggingface::tgi::backends::InitializeBackend() {
[[nodiscard]] [[nodiscard]]
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) { tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
tle::ExecutorConfig execConfig(1); tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1);
// Retrieve the compute capabilities to enable some options at runtime // Retrieve the compute capabilities to enable some options at runtime
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities(); const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
@ -55,12 +69,13 @@ tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &co
} }
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig( tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
uint32_t topK, const uint32_t topK,
float_t topP, const float_t topP,
float_t temperature, const float_t temperature,
float_t repetition_penalty, const float_t repetition_penalty,
float_t frequency_penalty, const float_t frequency_penalty,
uint64_t seed) { const uint64_t seed) noexcept {
return tle::SamplingConfig( return tle::SamplingConfig(
1, // TGI only use a single beam 1, // TGI only use a single beam
topK, topK,
@ -83,26 +98,29 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
const std::filesystem::path &executorWorker const std::filesystem::path &executorWorker
) : ) :
config(json::parse(std::ifstream(enginesFolder / "config.json"))), config(json::parse(std::ifstream(enginesFolder / "config.json"))),
executor( executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
enginesFolder, GetExecutorConfig(config, executorWorker.string())) {
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
GetExecutorConfig(config, executorWorker.string()
)) {
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>()); SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
}
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const { // Cache variables
return executor.canEnqueueRequests(); maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
} }
[[nodiscard("Returned number of requests needs to be consumed")]] [[nodiscard("Returned number of requests needs to be consumed")]]
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const { size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
return executor.getNumResponsesReady(); const auto numResponses = executor.getNumResponsesReady();
#ifndef NDEBUG
if(numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
#endif
return numResponses;
} }
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]] [[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit( tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const std::vector<tle::TokenIdType> &tokens, const std::vector<tle::TokenIdType> &tokens,
const uint32_t maxNewTokens,
const int32_t topK, const int32_t topK,
const float_t topP, const float_t topP,
const float_t temperature, const float_t temperature,
@ -110,37 +128,23 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const float_t frequency_penalty, const float_t frequency_penalty,
const uint64_t seed const uint64_t seed
) { ) {
#ifdef NDEBUG const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
SPDLOG_DEBUG( #ifndef NDEBUG
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"), {
tokens.size(), const auto &iterations = executor.getLatestIterationStats();
executor.getLatestIterationStats().back().numActiveRequests const auto &lastIteration = iterations.front();
);
#else SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
SPDLOG_DEBUG( SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"), SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
fmt::join(tokens, ", "), }
executor.getLatestIterationStats().front().numActiveRequests
);
#endif #endif
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed); const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
const auto output = tle::OutputConfig(true, false, false, true, false); const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
return executor.enqueueRequest( return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
tle::Request{tokens, maxNewTokens, true, sampling, output});
} }
[[nodiscard("Generated tokens result must be used")]] std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) { return executor.awaitResponses();
SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
return executor.awaitResponses(requestId);
}
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
SPDLOG_INFO("Shutting down executor");
executor.shutdown();
} }

backends/trtllm/scripts/install_tensorrt.sh

@ -2,12 +2,13 @@
set -ex set -ex
TRT_VER="10.2.0.19" TRT_VER_BASE="10.4.0"
CUDA_VER="12.5" TRT_VER_FULL="${TRT_VER_BASE}.26"
CUDNN_VER="9.2.1.18-1" CUDA_VER="12.6"
NCCL_VER="2.22.3-1+cuda12.5" CUDNN_VER="9.5.0.50-1"
CUBLAS_VER="12.5.3.2-1" NCCL_VER="2.22.3-1+cuda12.6"
NVRTC_VER="12.5.82-1" CUBLAS_VER="12.6.3.3-1"
NVRTC_VER="12.6.77-1"
for i in "$@"; do for i in "$@"; do
case $i in case $i in
@ -32,8 +33,9 @@ install_ubuntu_requirements() {
ARCH=$(uname -m) ARCH=$(uname -m)
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.0-1_all.deb dpkg -i cuda-keyring_1.1-1_all.deb
rm /etc/apt/sources.list.d/cuda-ubuntu2404-x86_64.list
apt-get update apt-get update
if [[ $(apt list --installed | grep libcudnn9) ]]; then if [[ $(apt list --installed | grep libcudnn9) ]]; then
@ -71,7 +73,7 @@ install_centos_requirements() {
install_tensorrt() { install_tensorrt() {
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))') #PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}") #PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
TRT_CUDA_VERSION="12.5" TRT_CUDA_VERSION="12.6"
if [ -z "$RELEASE_URL_TRT" ];then if [ -z "$RELEASE_URL_TRT" ];then
ARCH=${TRT_TARGETARCH} ARCH=${TRT_TARGETARCH}
@ -79,12 +81,12 @@ install_tensorrt() {
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-24.04" && OS="ubuntu-24.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_BASE}/tars/TensorRT-${TRT_VER_FULL}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
fi fi
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
tar -xf /tmp/TensorRT.tar -C /usr/local/ tar -xf /tmp/TensorRT.tar -C /usr/local/
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt mv /usr/local/TensorRT-${TRT_VER_FULL} /usr/local/tensorrt
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl # pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
rm -rf /tmp/TensorRT.tar rm -rf /tmp/TensorRT.tar
} }

backends/trtllm/src/backend.rs (deleted)

@ -1,330 +0,0 @@
use std::future::Future;
use std::path::Path;
use std::pin::{pin, Pin};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use async_trait::async_trait;
use cxx::UniquePtr;
use log::{error, warn};
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::time::{sleep, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::{Stream, StreamExt};
use tracing::{instrument, span, Level};
// use tokio::sync::RwLock;
use parking_lot::RwLock;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::UnsupportedModality;
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
// Value used to poll the state of the generation stream
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
type InferResult<T> = Result<T, InferError>;
pub(crate) struct Generation {
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
done: Arc<AtomicBool>,
}
/// Holds the user provided input to be executed along with a channel allowing
/// to bubble up all the generated tokens for that tokens the to end stream.
pub struct GenerationContext {
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokenizer: Arc<Tokenizer>,
tokens: Vec<u32>,
done: Arc<AtomicBool>,
queued: Instant,
start: Option<Instant>,
}
impl Stream for Generation {
type Item = usize;
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let interval = POLLING_INTERVAL_US.get_or_init(|| {
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
});
if !self.done.load(Ordering::Relaxed) {
let backend = pin!(self.executor.read());
let status = match backend.poll(ctx) {
Poll::Ready(executor_r) => {
let ready = executor_r.num_responses_ready();
if ready == 0 {
Poll::Pending
} else {
Poll::Ready(Some(ready))
}
}
Poll::Pending => Poll::Pending,
};
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_micros(*interval)).await;
waker.wake();
});
status
} else {
Poll::Ready(None) // end of stream
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(1, None)
}
}
unsafe impl Send for TensorRtLlmBackendImpl {}
unsafe impl Sync for TensorRtLlmBackendImpl {}
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
pub struct TensorRtLlmBackend {
tokenizer: Arc<Tokenizer>,
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
// the number of available tokens (read only) in the Generation stream
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
}
impl TensorRtLlmBackend {
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
) -> Result<Self, TensorRtLlmBackendError> {
Ok(TensorRtLlmBackend {
tokenizer: Arc::new(tokenizer),
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
engine_folder.as_ref().to_str().unwrap(),
executor_worker_path.as_ref().to_str().unwrap(),
))),
})
}
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
if request.top_n_tokens > 1 {
return Err(InferError::ValidationError(
ValidationError::TopNTokensDisabled,
));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(InferError::ValidationError(ValidationError::Grammar));
}
match request.inputs.len() {
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
2.. => Err(InferError::GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(text) => Ok(text),
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
},
}
}
fn generate(
&self,
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokens: Vec<u32>,
top_k: u32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) {
let tokenizer = Arc::clone(&self.tokenizer);
let executor = Arc::clone(&self.backend);
// Let's push this in async context
tokio::spawn(async move {
// Define the generation state
let mut generation = Generation {
executor: executor.clone(),
done: Arc::new(AtomicBool::new(false)),
};
// Define the context over the generation
// TODO(asap): Do we really need so many shared-ownership?
let ctx = Box::new(GenerationContext {
sender: sender.clone(),
tokenizer,
tokens: vec![],
done: Arc::clone(&generation.done),
start: None,
queued: Instant::now(),
});
// We are leaking the context on-purpose to avoid the box being dropped while there are
// still computation ongoing
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
let ctx_ = Box::leak(ctx);
// Submit the request to the batcher
let request_id = span!(Level::DEBUG, "submit")
.in_scope(|| async {
let mut handle = executor.write().await;
let request_id = handle.pin_mut().submit(
&tokens,
top_k as i32,
top_p,
temperature,
repetition_penalty,
frequency_penalty,
seed,
);
request_id
})
.await;
while let Some(_) = generation.next().await {
let mut executor_w = executor.write().await;
let executor = executor_w.pin_mut();
span!(Level::DEBUG, "decode")
.in_scope(|| async {
unsafe {
executor.stream_tokens(
request_id,
ctx_,
|ctx: *mut GenerationContext, step: GenerationStep| {
let inner_ctx = &mut *ctx;
// Update the timestamp at which the request started effectively
// Can be a bit off, would need to be before the callback, let's see
inner_ctx.start.get_or_insert(Instant::now());
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
// Ensure we are not running into errors
let parcel = if !step.has_error {
// Insert the latest generated token to the tracker
inner_ctx.tokens.push(step.token_id);
// Decode the token
let text = inner_ctx
.tokenizer
.decode(&[step.token_id], true)
.expect("Failed to decode token");
let special = inner_ctx
.tokenizer
.get_added_vocabulary()
.is_special_token(&text);
// Create the structure holding the token
let token = Token {
id: step.token_id,
text,
logprob: step.log_prob,
special,
};
if step.is_final {
let generated_text = inner_ctx
.tokenizer
.decode(&inner_ctx.tokens, true)
.expect("Failed to decode generated_tokens");
Ok(InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text: GeneratedText {
text: generated_text,
generated_tokens: inner_ctx.tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
},
start: inner_ctx.start.unwrap_or(Instant::now()),
queued: inner_ctx.queued,
})
} else {
Ok(InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
})
}
} else {
error!("Error caught while decoding: {}", &step.error_msg);
Err(InferError::GenerationError(step.error_msg))
};
// Send the parcel to the client
inner_ctx
.sender
.send(parcel)
.expect("Failed to sent msg through the channel");
},
);
}
})
.await;
}
// "Properly" free the shared context...
// TODO: clean that piece of sh** asap
unsafe {
let _ = Box::from_raw(ctx_);
}
});
}
}
#[async_trait]
impl Backend for TensorRtLlmBackend {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
// Let's add a few more validation
let input = TensorRtLlmBackend::validate(&request)?;
// Channel to stream the generated token as they come from the worker thread back to the transport layer
let (sender, receiver) = unbounded_channel();
// Unpack parameters
let params = &request.parameters;
// Preprocess the inputs to send to TRTLLM backend
let encoding = self
.tokenizer
.encode(input.as_str(), true)
.map_err(|e| InferError::GenerationError(e.to_string()))?;
// Generate the response
self.generate(
sender,
Vec::from(encoding.get_ids()),
params.top_k,
params.top_p,
params.temperature,
params.repetition_penalty,
params.frequency_penalty,
params.seed,
);
Ok(UnboundedReceiverStream::new(receiver))
}
async fn health(&self, _current_health: bool) -> bool {
true
}
}

backends/trtllm/src/errors.rs

@ -1,9 +1,16 @@
use std::path::PathBuf;
use thiserror::Error; use thiserror::Error;
use text_generation_router::server; use text_generation_router::server;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum TensorRtLlmBackendError { pub enum TensorRtLlmBackendError {
#[error("Provided engine folder {0} doesn't exist")]
EngineFolderDoesntExists(PathBuf),
#[error("Provided executorWorker binary path {0} doesn't exist")]
ExecutorWorkerNotFound(PathBuf),
#[error("TensorRT-LLM Runtime error: {0}")]
Runtime(String),
#[error("Tokenizer error: {0}")] #[error("Tokenizer error: {0}")]
Tokenizer(String), Tokenizer(String),
#[error("Argument validation error: {0}")] #[error("Argument validation error: {0}")]

backends/trtllm/src/ffi.cpp

@ -3,11 +3,13 @@
// //
#pragma once #pragma once
#include <cmath> #include <algorithm>
#include <exception> #include <exception>
#include <filesystem> #include <filesystem>
#include <functional>
#include <limits> #include <limits>
#include <iterator> #include <iterator>
#include <ranges>
#include <vector> #include <vector>
#include <spdlog/spdlog.h> #include <spdlog/spdlog.h>
@ -20,61 +22,59 @@ huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
) : TensorRtLlmBackend(engineFolder, executorWorker) {} ) : TensorRtLlmBackend(engineFolder, executorWorker) {}
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
return TensorRtLlmBackend::IsReady();
}
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit( uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty, rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
float_t frequency_penalty, uint64_t seed) { int32_t topK, float_t topP, float_t temperature,
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed) {
// This will copy all the items from the initial slice // This will copy all the items from the initial slice
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end())); std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
return TensorRtLlmBackend::Submit( return TensorRtLlmBackend::Submit(
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed); std::move(tokens_), maxNewTokens, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
} }
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens( std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
const uint64_t requestId, huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
huggingface::tgi::backends::GenerationContext *ctx, const auto responses = TensorRtLlmBackend::PullNewTokens();
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback) {
size_t numTokens = 0; auto steps = std::make_unique<std::vector<GenerationStep>>();
for (const auto &item: Poll(requestId)) { steps->reserve(responses.size());
GenerationStep step;
if (!item.hasError()) {
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
const auto decoded = item.getResult();
const auto token = decoded.outputTokenIds[0][0]; #ifndef NDEBUG
const auto isFinal = decoded.isFinal; SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size());
const auto logProb = decoded.logProbs.value()[0][0]; #endif
++numTokens; // Transform tle::Response to GenerationStep
std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal); const auto reqId = r.getRequestId();
step = huggingface::tgi::backends::GenerationStep{ if (!r.hasError()) {
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string()) const auto result = r.getResult();
return GenerationStep{
reqId,
static_cast<uint32_t>(result.outputTokenIds[0][0]),
result.logProbs.value()[0][0],
result.isFinal,
false,
std::string()
}; };
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
} else { } else {
// TODO : Return rest::Result with error return GenerationStep{
const auto what = item.getErrorMsg(); reqId,
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what); 0,
step = huggingface::tgi::backends::GenerationStep{ 0.0,
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what) true,
true,
std::move(r.getErrorMsg())
}; };
} }
});
callback(std::move(ctx), std::move(step)); return steps;
}
return numTokens;
} }
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl> std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) { huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
SPDLOG_INFO("Creating TensorRT-LLM Backend");
// Unconditionally call this to initialize and discover TRTLLM plugins // Unconditionally call this to initialize and discover TRTLLM plugins
InitializeBackend(); InitializeBackend();

backends/trtllm/src/lib.rs

@ -1,14 +1,16 @@
pub use backend::{GenerationContext, TensorRtLlmBackend}; pub use looper::TensorRtLlmBackendV2;
mod backend;
pub mod errors; pub mod errors;
mod looper;
mod utils;
#[cxx::bridge(namespace = "huggingface::tgi::backends")] #[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi { mod ffi {
/// Struct used as shared type between rust and C++ to represent the result /// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration /// of a single decoding iteration
#[derive(Debug, Clone)]
pub struct GenerationStep { pub struct GenerationStep {
request_id: u64,
token_id: u32, token_id: u32,
log_prob: f32, log_prob: f32,
is_final: bool, is_final: bool,
@ -16,10 +18,6 @@ mod ffi {
error_msg: String, error_msg: String,
} }
extern "Rust" {
type GenerationContext;
}
unsafe extern "C++" { unsafe extern "C++" {
include!("backends/trtllm/src/ffi.cpp"); include!("backends/trtllm/src/ffi.cpp");
@ -44,10 +42,7 @@ mod ffi {
fn CreateTensorRtLlmBackend( fn CreateTensorRtLlmBackend(
engine_folder: &str, engine_folder: &str,
executor_worker: &str, executor_worker: &str,
) -> UniquePtr<TensorRtLlmBackendImpl>; ) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
// #[rust_name = "is_ready"]
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
#[rust_name = "num_responses_ready"] #[rust_name = "num_responses_ready"]
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize; fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
@ -56,23 +51,18 @@ mod ffi {
fn Submit( fn Submit(
self: Pin<&mut TensorRtLlmBackendImpl>, self: Pin<&mut TensorRtLlmBackendImpl>,
tokens: &[u32], tokens: &[u32],
max_new_tokens: u32,
top_k: i32, top_k: i32,
top_p: f32, top_p: f32,
temperature: f32, temperature: f32,
repetition_penalty: f32, repetition_penalty: f32,
frequency_penalty: f32, frequency_penalty: f32,
seed: u64, seed: u64,
) -> u64; ) -> Result<u64>;
#[rust_name = "stream_tokens"] #[rust_name = "pull_tokens"]
unsafe fn StreamTokens( fn PullTokens(
self: Pin<&mut TensorRtLlmBackendImpl>, self: Pin<&mut TensorRtLlmBackendImpl>,
request_id: u64, ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
ctx: *mut GenerationContext,
cb: unsafe fn(*mut GenerationContext, GenerationStep),
) -> usize;
// #[rust_name = "shutdown"]
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
} }
} }

backends/trtllm/src/looper.rs (new)

@ -0,0 +1,395 @@
use std::hint;
use std::ops::Deref;
use std::path::Path;
use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::TryAcquireError;
use tokio::task::{spawn_blocking, JoinHandle};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
use text_generation_router::infer::InferError::{GenerationError, ValidationError};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::{
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
};
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
use crate::utils::first_line;
type InferResult<T> = Result<T, InferError>;
struct IdentifiableRequest<T> {
request_id: u64,
inner: T,
}
/// Wrap the requests along with the channel used to stream back to the client the decoded tokens
struct GenerationContext {
request: ValidGenerateRequest,
start: Option<Instant>,
queued: Instant,
streamer: UnboundedSender<InferResult<InferStreamResponse>>,
}
#[derive(Debug, Copy, Clone)]
struct DecodedToken {
id: u32,
log_prob: f32,
is_final: bool,
}
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
type Error = InferError;
fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
if !step.has_error {
Ok(Self {
id: step.token_id,
log_prob: step.log_prob,
is_final: step.is_final,
})
} else {
Err(GenerationError(step.error_msg.clone()))
}
}
}
/// Wraps the decoded token with the channel used to stream back to the client the decoded tokens
struct DecodedTokenContext {
token: DecodedToken,
start: Option<Instant>,
queued: Instant,
channel: UnboundedSender<InferResult<InferStreamResponse>>,
}
fn executor_status_looper(
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
max_inflight_requests: usize,
mut waiting_requests: UnboundedReceiver<GenerationContext>,
post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
) {
// Track the tuple (request_id, stream) for each request
let mut in_flights =
HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
// TODO: Does it need a spin-loop?
'scheduler: loop {
// Is there any request pending to be scheduled?
let awaiting_requests = waiting_requests.len();
for _ in 0..awaiting_requests {
// Retrieve all the requests
if let Some(mut ctx) = waiting_requests.blocking_recv() {
// Submit all the request to the executor and move the context to the in-flight tracker
let request = &ctx.request;
let generation_params = &request.parameters;
let stopping_params = &request.stopping_parameters;
let input_ids = request.input_ids.as_deref();
// Submit to the TensorRT-LLM executor for scheduling
match backend.pin_mut().submit(
&input_ids.unwrap(), // This is checked beforehand in validate()
stopping_params.max_new_tokens,
generation_params.top_k as i32,
generation_params.top_p,
generation_params.temperature,
generation_params.repetition_penalty,
generation_params.frequency_penalty,
generation_params.seed,
) {
Ok(request_id) => {
// Insert the context linked to the generated request id in the tracker
debug!("[in-flight] Added {}", request_id);
ctx.start = Some(Instant::now());
in_flights.insert(request_id, ctx);
}
Err(e) => {
// Return to the caller
let what = e.to_string();
error!(error = what.as_str(), "Failed to schedule request");
let err = Err(InferError::Overloaded(TryAcquireError::NoPermits));
if let Err(_) = ctx.streamer.send(err) {
error!("Failed to send back error to the client");
}
}
};
}
}
if backend.num_responses_ready() > 0 {
match backend.pin_mut().pull_tokens() {
Ok(responses) => {
// Iterate through all the decoded token
for step in responses.deref() {
if let Some(ctx) = in_flights.get(&step.request_id) {
// Remove from tracked requests
let parcel =
DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
token: dt,
start: ctx.start,
queued: ctx.queued,
channel: ctx.streamer.clone(),
});
// Submit the work to p:the post_processor
let posted = post_processor_sender.send((step.request_id, parcel));
if posted.is_err() || step.is_final {
debug!("Removing {}", step.request_id);
let _ = in_flights.remove(&step.request_id);
}
} else {
warn!("Untracked request {}", step.request_id,);
}
}
}
Err(ref err) => {
error!("Failed to get responses from the executor: {}.", err.what());
break 'scheduler;
}
}
}
// Hint the CPU we are spin-locking
hint::spin_loop();
}
}
fn post_processor_looper(
tokenizer: Tokenizer,
max_num_tokens: usize,
max_inflight_requests: usize,
mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
) {
let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
'post_processor: loop {
if decoded_tokens.is_closed() {
warn!("Post processor IPC is closed, loop will exit now.");
break 'post_processor;
}
if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
match decoded {
Ok(ctx) => {
states
.entry(request_id)
.and_modify(|s| s.push(*&ctx.token.id))
.or_insert_with(|| {
let mut state = Vec::with_capacity(max_num_tokens);
state.push(*&ctx.token.id);
state
});
let out = match tokenizer.decode(&[ctx.token.id], false) {
Ok(text) => {
let is_special =
tokenizer.get_added_vocabulary().is_special_token(&text);
let token = Token {
id: ctx.token.id,
text,
logprob: ctx.token.log_prob,
special: is_special,
};
let out = if !ctx.token.is_final {
InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}
} else {
let tokens = states.remove(&request_id).unwrap();
let text = tokenizer.decode(&tokens, true);
let generated_text = GeneratedText {
text: text.unwrap(),
generated_tokens: tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
};
InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text,
start: ctx.start.unwrap(),
queued: ctx.queued,
}
};
Ok(out)
}
Err(err) => Err(GenerationError(err.to_string())),
};
if let Err(_) = ctx.channel.send(out) {
warn!("Failed to send decoded token back to the user")
}
}
Err(_err) => {
todo!("what do we do?")
}
}
}
}
}
fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
engine_folder: P,
executor_worker_path: PP,
) -> Result<(String, String), TensorRtLlmBackendError> {
// Retrieve paths as &str for the backend creation
let engine_folder = engine_folder.as_ref();
let executor_worker_path = executor_worker_path.as_ref();
// Ensure the engine folder exists
if !engine_folder.exists() {
let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
// Ensure executor worker binary exists
if !executor_worker_path.exists() {
let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(executor_worker_path.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
let engine_folder = String::from(
engine_folder
.to_str()
.expect("Failed to convert engine_folder to valid UTF-8"),
);
let executor_worker_path = String::from(
executor_worker_path
.to_str()
.expect("Failed to convert executor_worker_path to valid UTF-8"),
);
Ok((engine_folder, executor_worker_path))
}
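// SAFETY: assumed sound because the FFI backend is moved into the executor looper thread and only ever driven from there; this marker only permits that move.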
unsafe impl Send for TensorRtLlmBackendImpl {}
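/// Thin handle over the two background looper threads and the channel used to submit new generation requests to the executor looper.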
pub struct TensorRtLlmBackendV2 {
executor_looper: JoinHandle<()>,
post_processor_looper: JoinHandle<()>,
executor: UnboundedSender<GenerationContext>,
}
impl TensorRtLlmBackendV2 {
pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
max_inflight_requests: usize,
) -> Result<Self, TensorRtLlmBackendError> {
let (engine_folder, executor_worker_path) =
ensure_paths_exist(engine_folder, executor_worker_path)?;
// Allocate the IPC layer to communicate with the backend
let (executor_sender, executor_receiver) = unbounded_channel();
let (post_processor_sender, post_processor_receiver) = unbounded_channel();
// Create the FFI backend
let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
// Executor looper is responsible for scheduling new requests and pulling their state from the executor at regular intervals
let executor_looper = spawn_blocking(move || {
executor_status_looper(
backend,
max_inflight_requests,
executor_receiver,
post_processor_sender,
)
});
// Post processor looper is responsible for receiving batches of tokens, decoding them and sending them back to the user
let post_processor_looper = spawn_blocking(move || {
post_processor_looper(
tokenizer,
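// max_num_tokens: 512 is only a pre-allocation hint for each request's token buffer; the vectors grow past it if needed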
512,
max_inflight_requests,
post_processor_receiver,
)
});
Ok(TensorRtLlmBackendV2 {
executor_looper,
post_processor_looper,
executor: executor_sender,
})
}
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
if request.input_ids.is_none() {
return Err(ValidationError(UnsupportedModality("No token provided")));
}
if request.top_n_tokens > 1 {
return Err(ValidationError(TopNTokensDisabled));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(ValidationError(Grammar));
}
match request.inputs.len() {
0 => Err(ValidationError(EmptyInput)),
2.. => Err(GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(_) => Ok(()),
Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
},
}
}
}
#[async_trait]
impl Backend for TensorRtLlmBackendV2 {
fn schedule(
&self,
inner: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
Self::validate(&inner)?;
// Open-up the stream to send tokens
let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
// Send the context to the executor for scheduling
let queued = Instant::now();
match self.executor.send(GenerationContext {
request: inner,
start: None,
queued,
streamer,
}) {
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
Err(_) => Err(GenerationError(
"Failed to submit request to the backend".into(),
)),
}
}
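// The backend is considered healthy as long as the request path is healthy and both background loopers are still running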
async fn health(&self, current_health: bool) -> bool {
current_health
& !self.executor_looper.is_finished()
& !self.post_processor_looper.is_finished()
}
}


@@ -1,10 +1,16 @@
use std::path::{Path, PathBuf};
use clap::Parser;
use hf_hub::api::tokio::{Api, ApiBuilder};
use hf_hub::{Cache, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::info;
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
use text_generation_router::server::get_base_tokenizer;
use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, HubTokenizerConfig};
/// App Configuration
#[derive(Parser, Debug)]
@@ -58,6 +64,130 @@ struct Args {
usage_stats: usage_stats::UsageStatsLevel,
}
async fn get_tokenizer(
tokenizer_name: &str,
tokenizer_config_path: Option<&str>,
revision: Option<&str>,
) -> Option<Tokenizer> {
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
let local_path = Path::new(tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
Type::None
};
// Load tokenizer and model info
let (
tokenizer_filename,
_config_filename,
tokenizer_config_filename,
_preprocessor_config_filename,
_processor_config_filename,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or_else(|| "main").to_string(),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main").to_string(),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
)
}
};
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
}
#[tokio::main]
async fn main() -> Result<(), TensorRtLlmBackendError> {
// Get args
@@ -124,18 +254,26 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
)));
}
// Create the backend
let tokenizer = get_tokenizer(
&tokenizer_name,
tokenizer_config_path.as_deref(),
revision.as_deref(),
)
.await
.expect("Failed to retrieve tokenizer implementation");
info!("Successfully retrieved tokenizer {}", &tokenizer_name);
let backend = TensorRtLlmBackendV2::new(
tokenizer,
model_id,
executor_worker,
max_concurrent_requests,
)?;
info!("Successfully created backend");
// Run server
server::run(
backend,
max_concurrent_requests,
@@ -145,7 +283,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
max_input_tokens,
max_total_tokens,
validation_workers,
auth_token,
tokenizer_name,
tokenizer_config_path,
revision,


@@ -0,0 +1,22 @@
///
/// Extract the first line of the provided string reference.
/// If there are no lines in the buffer, it returns a string
/// whose content is defined by the content of `fail`.
/// # Arguments
///
/// * `s`: The string buffer to extract the first-line from
/// * `fail`: The fallback content returned if no lines are
/// present in `s`
///
/// returns: String
///
/// # Examples
///
/// ```
/// let s = "My name is Morgan.\n I'm working at Hugging Face.";
/// first_line(s, "No line in string");
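/// // => "My name is Morgan."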
/// ```
#[inline]
pub(crate) fn first_line(s: &str, fail: &str) -> String {
s.lines().next().unwrap_or(fail).to_string()
}