wip
wip refacto refacto Initial setup for CXX binding to TRTLLM Working FFI call for TGI and TRTLLM backend Remove unused parameters annd force tokenizer name to be set Overall build TRTLLM and deps through CMake build system Enable end to end CMake build First version loading engines and making it ready for inference Remembering to check how we can detect support for chunked context Move to latest TensorRT-LLM version Specify which default log level to use depending on CMake build type make leader executor mode working unconditionally call InitializeBackend on the FFI layer bind to CUDA::nvml to retrieve compute capabilities at runtime updated logic and comment to detect cuda compute capabilities implement the Stream method to send new tokens through a callback use spdlog release 1.14.1 moving forward update trtllm to latest version a96cccafcf6365c128f004f779160951f8c0801c correctly tell cmake to build dependent tensorrt-llm required libraries create cmake install target to put everything relevant in installation folder add auth_token CLI argument to provide hf hub authentification token allow converting huggingface::tokenizers error to TensorRtLlmBackendError use correct include for spdlog include guard to build example in cmakelists working setup of the ffi layer remove fmt import use external fmt lib end to end ffi flow working make sure to track include/ffi.h to trigger rebuild from cargo impl the rust backend which currently cannot move the actual computation in background thread expose shutdown function at ffi layer impl RwLock scenario for TensorRtLllmBackend oops missing c++ backend definitions compute the number of maximum new tokens for each request independently make sure the context is not dropped in the middle of the async decoding. remove unnecessary log add all the necessary plumbery to return the generated content update invalid doc in cpp file correctly forward back the log probabilities remove unneeded scope variable for now refactor Stream impl for Generation to factorise code expose the internal missing start/queue timestamp forward tgi parameters rep/freq penalty add some more validation about grammar not supported define a shared struct to hold the result of a decoding step expose information about potential error happening while decoding remove logging add logging in case of decoding error make sure executor_worker is provided add initial Dockerfile for TRTLLM backend add some more information in CMakeLists.txt to correctly install executorWorker add some more information in CMakeLists.txt to correctly find and install nvrtc wrapper simplify prebuilt trtllm libraries name definition do the same name definition stuff for tensorrt_llm_executor_static leverage pkg-config to probe libraries paths and reuse new install structure from cmake fix bad copy/past missing nvinfer linkage direction align all the linker search dependency add missing pkgconfig folder for MPI in Dockerfile correctly setup linking search path for runtime layer fix missing / before tgi lib path adding missing ld_library_path for cuda stubs in Dockerfile update tgi entrypoint commenting out Python part for TensorRT installation refactored docker image move to TensorRT-LLM v0.11.0 make docker linter happy with same capitalization rule fix typo refactor the compute capabilities detection along with num gpus update TensorRT-LLM to latest version update TensorRT install script to latest update build.rs to link to cuda 12.5 add missing dependant libraries for linking clean up a bit install to decoder_attention target add some custom stuff for nccl linkage fix envvar CARGO_CFG_TARGET_ARCH set at runtime vs compile time use std::env::const::ARCH make sure variable live long enough... look for cuda 12.5 add some more basic info in README.md
This commit is contained in:
parent
0b95693fb8
commit
ddbbf6b50c
|
@ -2,3 +2,5 @@ aml
|
|||
target
|
||||
server/transformers
|
||||
server/flash-attention
|
||||
cmake-build-debug/
|
||||
cmake-build-release/
|
||||
|
|
File diff suppressed because it is too large
Load Diff
10
Cargo.toml
10
Cargo.toml
|
@ -1,11 +1,11 @@
|
|||
[workspace]
|
||||
members = [
|
||||
"benchmark",
|
||||
"router",
|
||||
"router/client",
|
||||
"router/grpc-metadata",
|
||||
# "benchmark",
|
||||
"backends/v3",
|
||||
# "backends/client",
|
||||
"backends/grpc-metadata",
|
||||
"launcher"
|
||||
]
|
||||
, "backends/trtllm"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
# All the tooling for CUDA
|
||||
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
||||
|
||||
WORKDIR /usr/src/tgi/backends/trtllm
|
||||
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget
|
||||
|
||||
COPY . /usr/src/tgi
|
||||
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
|
||||
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
|
||||
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl
|
||||
|
||||
# All the tooling for Rust
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
|
||||
WORKDIR /usr/src
|
||||
|
||||
# Include CUDA related libraries and tools to the Rust based image
|
||||
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
|
||||
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
|
||||
ENV PATH=/usr/local/cuda/bin:$PATH
|
||||
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH
|
||||
|
||||
RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3
|
|
@ -0,0 +1,69 @@
|
|||
cmake_minimum_required(VERSION 3.20)
|
||||
|
||||
project(tgi-trtllm-backend VERSION 1.0.0)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
|
||||
include(FetchContent)
|
||||
include(ExternalProject)
|
||||
|
||||
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
|
||||
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
|
||||
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
|
||||
|
||||
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
|
||||
find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::nvml)
|
||||
|
||||
#### External dependencies ####
|
||||
include(cmake/fmt.cmake)
|
||||
include(cmake/json.cmake)
|
||||
include(cmake/spdlog.cmake)
|
||||
include(cmake/trtllm.cmake)
|
||||
|
||||
# Let's build TRTLLM as part of CMake
|
||||
add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
|
||||
|
||||
# Tell CMake to need try to override the RPATH for executorWorker as it has not information on how to do so
|
||||
set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)
|
||||
|
||||
# TGI TRTLLM Backend definition
|
||||
add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h)
|
||||
include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
|
||||
target_include_directories(tgi_trtllm_backend_impl PRIVATE
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
|
||||
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::nvml)
|
||||
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
|
||||
|
||||
if (${TGI_TRTLLM_BACKEND_BUILD_EXAMPLES})
|
||||
add_executable(tgi_trtllm_backend_example bin/example.cpp)
|
||||
target_link_libraries(tgi_trtllm_backend_example PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tgi_trtllm_backend_impl)
|
||||
target_link_libraries(tgi_trtllm_backend_example PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
|
||||
endif ()
|
||||
|
||||
# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
|
||||
install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
|
||||
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
|
||||
|
||||
#### Unit Tests ####
|
||||
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
|
||||
message(STATUS "Building tests")
|
||||
FetchContent_Declare(
|
||||
Catch2
|
||||
GIT_REPOSITORY https://github.com/catchorg/Catch2
|
||||
GIT_TAG v3.6.0
|
||||
)
|
||||
FetchContent_MakeAvailable(Catch2)
|
||||
|
||||
add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp)
|
||||
target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
|
||||
include(CTest)
|
||||
include(Catch)
|
||||
catch_discover_tests(tgi_trtllm_backend_tests)
|
||||
endif ()
|
|
@ -0,0 +1,26 @@
|
|||
[package]
|
||||
name = "text-generation-backends-trtllm"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1"
|
||||
async-stream = "0.3"
|
||||
cxx = "1.0"
|
||||
text-generation-router = { path = "../../router" }
|
||||
tokenizers = { version = "0.19", features = ["hf-hub"] }
|
||||
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.15"
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
thiserror = "1.0.62"
|
||||
tracing = "0.1"
|
||||
tracing-opentelemetry = "0.24"
|
||||
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
||||
log = { version = "0.4", features = [] }
|
||||
|
||||
[build-dependencies]
|
||||
cmake = "0.1"
|
||||
cxx-build = { version = "1.0", features = ["parallel"] }
|
||||
pkg-config = "0.3"
|
|
@ -0,0 +1,100 @@
|
|||
ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
|
||||
ARG OMPI_VERSION="4.1.6"
|
||||
|
||||
# Build dependencies resolver stage
|
||||
FROM lukemathwalker/cargo-chef:latest AS chef
|
||||
WORKDIR /usr/src/text-generation-inference
|
||||
|
||||
FROM chef AS planner
|
||||
COPY . .
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
# CUDA dependent dependencies resolver stage
|
||||
FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt update && apt install -y \
|
||||
build-essential \
|
||||
cmake \
|
||||
curl \
|
||||
gcc \
|
||||
g++ \
|
||||
git \
|
||||
git-lfs \
|
||||
libssl-dev \
|
||||
ninja-build \
|
||||
pkg-config \
|
||||
python3 \
|
||||
python3-setuptools \
|
||||
tar \
|
||||
wget
|
||||
|
||||
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
|
||||
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
|
||||
|
||||
# Install OpenMPI
|
||||
FROM cuda-builder AS mpi-builder
|
||||
ARG OMPI_VERSION
|
||||
|
||||
ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
|
||||
RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
|
||||
mkdir /usr/src/mpi && \
|
||||
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
|
||||
cd /usr/src/mpi && \
|
||||
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \
|
||||
make -j all && \
|
||||
make install && \
|
||||
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
|
||||
|
||||
# Install TensorRT
|
||||
FROM cuda-builder AS trt-builder
|
||||
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
|
||||
RUN chmod +x /opt/install_tensorrt.sh && \
|
||||
/opt/install_tensorrt.sh
|
||||
|
||||
# Build Backend
|
||||
FROM cuda-builder AS tgi-builder
|
||||
WORKDIR /usr/src/text-generation-inference
|
||||
|
||||
# Install Rust
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
|
||||
chmod -R a+w /root/.rustup && \
|
||||
chmod -R a+w /root/.cargo
|
||||
|
||||
ENV PATH="/root/.cargo/bin:$PATH"
|
||||
RUN cargo install cargo-chef
|
||||
|
||||
# Cache dependencies
|
||||
COPY --from=planner /usr/src/text-generation-inference/recipe.json .
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
|
||||
# Build actual TGI
|
||||
ARG CUDA_ARCH_LIST
|
||||
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
|
||||
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
|
||||
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH"
|
||||
|
||||
COPY . .
|
||||
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
||||
RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
|
||||
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm
|
||||
|
||||
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
|
||||
WORKDIR /usr/local/tgi/bin
|
||||
|
||||
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
|
||||
|
||||
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
||||
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
|
||||
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
|
||||
|
||||
FROM runtime
|
||||
|
||||
LABEL co.huggingface.vendor="Hugging Face Inc."
|
||||
LABEL org.opencontainers.image.authors="hardware@hf.co"
|
||||
|
||||
ENTRYPOINT ["./text-generation-launcher"]
|
||||
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]
|
|
@ -0,0 +1,46 @@
|
|||
# Text Generation Inference - TensorRT-LLM Backend Implementation
|
||||
|
||||
## Description
|
||||
|
||||
This folder provides the sources of the TensorRT-LLM backend implementation powered by TensorRT-LLM Executor new API
|
||||
|
||||
## Simplified Request Sequence
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
actor User
|
||||
participant TextGenerationInference.HttpServer
|
||||
participant TextGenerationInference.TensorRtLlmBackend
|
||||
participant TextGenerationInference.TensorRtLlmWorkerThread
|
||||
participant TensorRtLlm.Executor
|
||||
participant Nvidia.Gpu
|
||||
User ->> TextGenerationInference.HttpServer: POST /generate
|
||||
TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters
|
||||
TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request
|
||||
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher
|
||||
activate Nvidia.Gpu
|
||||
TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the poll for execution
|
||||
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Response with an unique request identifier
|
||||
rect rgb(10, 92, 54)
|
||||
loop every 100us
|
||||
rect rgb(15, 81, 50)
|
||||
alt Acquire lock to query executor
|
||||
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated
|
||||
else There are new generated tokens
|
||||
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens
|
||||
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted)
|
||||
rect rgb(11, 110, 79)
|
||||
alt Generated token is final
|
||||
TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU
|
||||
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection
|
||||
else Generated token is not final
|
||||
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
deactivate Nvidia.Gpu
|
||||
end
|
||||
end
|
||||
|
||||
```
|
|
@ -0,0 +1,151 @@
|
|||
use std::env;
|
||||
use std::env::consts::ARCH;
|
||||
use std::path::{absolute, PathBuf};
|
||||
|
||||
use cxx_build::CFG;
|
||||
use pkg_config;
|
||||
|
||||
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
|
||||
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
|
||||
const CUDA_REQUIRED_VERSION: &str = "12.5";
|
||||
const MPI_REQUIRED_VERSION: &str = "4.1";
|
||||
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
|
||||
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
|
||||
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
|
||||
|
||||
// Dependencies
|
||||
const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
|
||||
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
|
||||
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
|
||||
("dylib", "tensorrt-llm"),
|
||||
("static", "tensorrt_llm_executor_static"),
|
||||
("dylib", "tensorrt_llm_nvrtc_wrapper"),
|
||||
("dylib", "nvinfer_plugin_tensorrt_llm"),
|
||||
("dylib", "decoder_attention"),
|
||||
];
|
||||
|
||||
macro_rules! probe {
|
||||
($name: expr, $version: expr) => {
|
||||
if let Err(_) = pkg_config::probe_library($name) {
|
||||
pkg_config::probe_library(&format!("{}-{}", $name, $version))
|
||||
.expect(&format!("Failed to locate {}", $name));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
|
||||
// Build the backend implementation through CMake
|
||||
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
|
||||
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
|
||||
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default
|
||||
|
||||
let mut install_path = PathBuf::from(install_path);
|
||||
if !install_path.is_absolute() {
|
||||
install_path = absolute(out_dir).expect("cannot happen").join(install_path);
|
||||
}
|
||||
|
||||
let _ = cmake::Config::new(".")
|
||||
.uses_cxx11()
|
||||
.generator("Ninja")
|
||||
.profile(match is_debug {
|
||||
true => "Debug",
|
||||
false => "Release",
|
||||
})
|
||||
.env("OPT_LEVEL", opt_level)
|
||||
.define("CMAKE_INSTALL_PREFIX", &install_path)
|
||||
.define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
|
||||
.define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
|
||||
.define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
|
||||
.build();
|
||||
|
||||
// Additional transitive CMake dependencies
|
||||
let deps_folder = out_dir.join("build").join("_deps");
|
||||
for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
|
||||
let dep_name = match is_debug {
|
||||
true => format!("{}d", dependency),
|
||||
false => String::from(dependency),
|
||||
};
|
||||
let dep_path = deps_folder.join(format!("{}-build", dependency));
|
||||
println!("cargo:rustc-link-search={}", dep_path.display());
|
||||
println!("cargo:rustc-link-lib=static={}", dep_name);
|
||||
}
|
||||
|
||||
// Emit linkage information from the artifacts we just built
|
||||
let install_lib_path = install_path.join("lib");
|
||||
|
||||
println!(
|
||||
r"cargo:warning=Adding link search path: {}",
|
||||
install_lib_path.display()
|
||||
);
|
||||
println!(r"cargo:rustc-link-search={}", install_lib_path.display());
|
||||
|
||||
(PathBuf::from(install_path), deps_folder)
|
||||
}
|
||||
|
||||
fn build_ffi_layer(deps_folder: &PathBuf) {
|
||||
CFG.include_prefix = "backends/trtllm";
|
||||
cxx_build::bridge("src/lib.rs")
|
||||
.static_flag(true)
|
||||
.include(deps_folder.join("fmt-src").join("include"))
|
||||
.include(deps_folder.join("spdlog-src").join("include"))
|
||||
.include(deps_folder.join("json-src").join("include"))
|
||||
.include(deps_folder.join("trtllm-src").join("cpp").join("include"))
|
||||
.include("/usr/local/cuda/include")
|
||||
.include("/usr/local/tensorrt/include")
|
||||
.file("src/ffi.cpp")
|
||||
.std("c++20")
|
||||
.compile("tgi_trtllm_backend");
|
||||
|
||||
println!("cargo:rerun-if-changed=CMakeLists.txt");
|
||||
println!("cargo:rerun-if-changed=include/backend.h");
|
||||
println!("cargo:rerun-if-changed=lib/backend.cpp");
|
||||
println!("cargo:rerun-if-changed=include/ffi.h");
|
||||
println!("cargo:rerun-if-changed=src/ffi.cpp");
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Misc variables
|
||||
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
|
||||
let build_profile = env::var("PROFILE").unwrap();
|
||||
let (is_debug, opt_level) = match build_profile.as_ref() {
|
||||
"debug" => (true, "0"),
|
||||
_ => (false, "3"),
|
||||
};
|
||||
|
||||
// Build the backend
|
||||
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
|
||||
|
||||
// Build the FFI layer calling the backend above
|
||||
build_ffi_layer(&deps_folder);
|
||||
|
||||
// Emit linkage search path
|
||||
probe!("ompi", MPI_REQUIRED_VERSION);
|
||||
|
||||
// Probe CUDA & co. with pkg-config
|
||||
CUDA_TRANSITIVE_DEPS.iter().for_each(|name| {
|
||||
probe!(name, CUDA_REQUIRED_VERSION);
|
||||
});
|
||||
|
||||
// NCCL is slightly trickier because it might not have a pkgconfig installed
|
||||
let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH);
|
||||
let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default);
|
||||
println!(r"cargo:rustc-link-search=native={}", nccl_library_path);
|
||||
println!("cargo:rustc-link-lib=dylib=nccl");
|
||||
|
||||
// TensorRT
|
||||
let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib");
|
||||
println!(r"cargo:rustc-link-search=native={}", tensort_library_path);
|
||||
println!("cargo:rustc-link-lib=dylib=nvinfer");
|
||||
|
||||
// TensorRT-LLM
|
||||
TENSORRT_LLM_TRANSITIVE_DEPS
|
||||
.iter()
|
||||
.for_each(|(link_type, name)| {
|
||||
println!("cargo:rustc-link-lib={}={}", link_type, name);
|
||||
});
|
||||
|
||||
// Backend
|
||||
BACKEND_DEPS.iter().for_each(|name| {
|
||||
println!("cargo:rustc-link-lib=static={}", name);
|
||||
});
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||
GIT_TAG 11.0.1
|
||||
)
|
||||
FetchContent_MakeAvailable(fmt)
|
|
@ -0,0 +1,5 @@
|
|||
fetchcontent_declare(
|
||||
json
|
||||
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
|
||||
)
|
||||
fetchcontent_makeavailable(json)
|
|
@ -0,0 +1,17 @@
|
|||
set(SPDLOG_USE_FMT ON)
|
||||
set(SPDLOG_BUILD_SHARED OFF)
|
||||
set(SPDLOG_FMT_EXTERNAL ON)
|
||||
|
||||
# Define the level at which SPDLOG_ compilation level is defined
|
||||
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
|
||||
else ()
|
||||
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
|
||||
endif ()
|
||||
|
||||
fetchcontent_declare(
|
||||
spdlog
|
||||
GIT_REPOSITORY https://github.com/gabime/spdlog.git
|
||||
GIT_TAG v1.14.1
|
||||
)
|
||||
fetchcontent_makeavailable(spdlog)
|
|
@ -0,0 +1,42 @@
|
|||
set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
|
||||
set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})
|
||||
|
||||
set(USE_CXX11_ABI ON)
|
||||
set(BUILD_PYT OFF)
|
||||
set(BUILD_PYBIND OFF)
|
||||
set(BUILD_MICRO_BENCHMARKS OFF)
|
||||
set(BUILD_BENCHMARKS OFF)
|
||||
set(BUILD_TESTS OFF)
|
||||
set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})
|
||||
|
||||
message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||
|
||||
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||
set(FAST_BUILD ON)
|
||||
set(NVTX_DISABLE OFF)
|
||||
else ()
|
||||
set(FAST_BUILD OFF)
|
||||
set(FAST_MATH ON)
|
||||
set(NVTX_DISABLE ON)
|
||||
endif ()
|
||||
|
||||
fetchcontent_declare(
|
||||
trtllm
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
|
||||
GIT_TAG bca9a33b022dc6a924bf7913137feed3d28b602d
|
||||
GIT_SHALLOW FALSE
|
||||
)
|
||||
fetchcontent_makeavailable(trtllm)
|
||||
|
||||
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
|
||||
execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
|
||||
execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
|
||||
|
||||
# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here
|
||||
set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name")
|
||||
set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}"
|
||||
CACHE INTERNAL "nvrtc wrapper library path")
|
||||
|
||||
# The same Executor Static library
|
||||
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name")
|
||||
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path")
|
|
@ -0,0 +1,121 @@
|
|||
//
|
||||
// Created by Morgan Funtowicz on 6/30/24.
|
||||
//
|
||||
|
||||
#ifndef TGI_TRTLLM_BACKEND_H
|
||||
#define TGI_TRTLLM_BACKEND_H
|
||||
|
||||
#include <cmath>
|
||||
#include <filesystem>
|
||||
#include <span>
|
||||
#include <vector>
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <tensorrt_llm/runtime/common.h>
|
||||
#include <tensorrt_llm/executor/executor.h>
|
||||
#include <tensorrt_llm/plugins/api/tllmPlugin.h>
|
||||
|
||||
using json = nlohmann::json;
|
||||
namespace tle = tensorrt_llm::executor;
|
||||
|
||||
namespace huggingface::tgi::backends {
|
||||
using RequestId = tle::IdType;
|
||||
using TokenId = tle::TokenIdType;
|
||||
|
||||
/**
|
||||
* Initialize all the components required by TRTLLM.
|
||||
* It is required to call this function before attempting to load any engine
|
||||
*/
|
||||
void InitializeBackend();
|
||||
|
||||
/**
|
||||
*
|
||||
* @param config TensorRT-LLM configuration object
|
||||
* @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
|
||||
* @return
|
||||
*/
|
||||
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
|
||||
|
||||
/**
|
||||
* Get the sampling configuration from the parameters provided by TGI
|
||||
* @param topK
|
||||
* @param topP
|
||||
* @param temperature
|
||||
* @param repetition_penalty
|
||||
* @param frequency_penalty
|
||||
* @param seed
|
||||
* @return
|
||||
*/
|
||||
tle::SamplingConfig GetSamplingConfig(
|
||||
uint32_t topK,
|
||||
float_t topP,
|
||||
float_t temperature,
|
||||
float_t repetition_penalty,
|
||||
float_t frequency_penalty,
|
||||
uint64_t seed
|
||||
);
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
class TensorRtLlmBackend {
|
||||
private:
|
||||
const json config;
|
||||
tle::Executor executor;
|
||||
|
||||
public:
|
||||
explicit TensorRtLlmBackend(
|
||||
const std::filesystem::path &engineFolder,
|
||||
const std::filesystem::path &executorWorker
|
||||
);
|
||||
|
||||
/**
|
||||
* Indicate if the backend is ready to accept incoming request
|
||||
* @return true if ready, false otherwise
|
||||
*/
|
||||
[[nodiscard]] bool IsReady() const;
|
||||
|
||||
/**
|
||||
* Query the executor for the number of token available for pulling
|
||||
* @return
|
||||
*/
|
||||
[[nodiscard]] size_t NumResponsesReady() const;
|
||||
|
||||
/**
|
||||
* Submit a new generation task to the executor
|
||||
* @param tokens
|
||||
* @param topK
|
||||
* @param topP
|
||||
* @param temperature
|
||||
* @param repetition_penalty
|
||||
* @param frequency_penalty
|
||||
* @param seed
|
||||
* @return Request id related to this generation for reference
|
||||
*/
|
||||
[[nodiscard]] RequestId Submit(
|
||||
const std::vector<TokenId> &tokens,
|
||||
int32_t topK,
|
||||
float_t topP,
|
||||
float_t temperature,
|
||||
float_t repetition_penalty,
|
||||
float_t frequency_penalty,
|
||||
uint64_t seed
|
||||
);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param requestId The request id to poll the generation results
|
||||
* @return
|
||||
*/
|
||||
std::vector<tle::Response> Poll(RequestId requestId);
|
||||
|
||||
/**
|
||||
* Stop the underlying executor
|
||||
*/
|
||||
void Shutdown();
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
#endif //TGI_TRTLLM_BACKEND_H
|
|
@ -0,0 +1,75 @@
|
|||
//
|
||||
// Created by mfuntowicz on 7/11/24.
|
||||
//
|
||||
|
||||
#ifndef TGI_TRTLLM_BACKEND_FFI_H
|
||||
#define TGI_TRTLLM_BACKEND_FFI_H
|
||||
|
||||
#include <cstddef>
|
||||
#include "backend.h"
|
||||
|
||||
namespace huggingface::tgi::backends {
|
||||
class TensorRtLlmBackendImpl;
|
||||
}
|
||||
|
||||
#include "backends/trtllm/src/lib.rs.h"
|
||||
|
||||
|
||||
namespace huggingface::tgi::backends {
|
||||
|
||||
// struct GenerationContext;
|
||||
|
||||
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
|
||||
public:
|
||||
/***
|
||||
*
|
||||
* @param engineFolder
|
||||
* @param executorWorker
|
||||
*/
|
||||
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
|
||||
|
||||
/***
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
bool IsReady() const;
|
||||
|
||||
/***
|
||||
*
|
||||
* @param tokens
|
||||
* @param topK
|
||||
* @param topP
|
||||
* @param temperature
|
||||
* @param repetition_penalty
|
||||
* @param frequency_penalty
|
||||
* @param seed
|
||||
* @return
|
||||
*/
|
||||
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
|
||||
uint64_t
|
||||
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
|
||||
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
|
||||
|
||||
/***
|
||||
*
|
||||
* @param requestId
|
||||
* @param ctx
|
||||
* @param callback
|
||||
* @return
|
||||
*/
|
||||
size_t StreamTokens(
|
||||
const RequestId requestId,
|
||||
huggingface::tgi::backends::GenerationContext *ctx,
|
||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
||||
huggingface::tgi::backends::GenerationStep)> callback);
|
||||
};
|
||||
|
||||
/***
|
||||
*
|
||||
* @param engineFolder
|
||||
* @return
|
||||
*/
|
||||
std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
|
||||
}
|
||||
|
||||
#endif //TGI_TRTLLM_BACKEND_FFI_H
|
|
@ -0,0 +1,59 @@
|
|||
//
|
||||
// Created by mfuntowicz on 7/23/24.
|
||||
//
|
||||
|
||||
#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H
|
||||
#define TGI_TRTLLM_BACKEND_HARDWARE_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <fmt/base.h>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <nvml.h>
|
||||
|
||||
namespace huggingface::hardware::cuda {
|
||||
|
||||
#define AMPERE_SM_MAJOR 8
|
||||
#define HOPPER_SM_MAJOR 8
|
||||
|
||||
/**
|
||||
* Store information about the version of the CUDA Compute Capabilities detected on the device
|
||||
*/
|
||||
struct CudaComputeCapabilities {
|
||||
int32_t major;
|
||||
int32_t minor;
|
||||
|
||||
[[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
|
||||
|
||||
[[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; }
|
||||
};
|
||||
|
||||
CudaComputeCapabilities GetCudaComputeCapabilities() {
|
||||
// Get the compute capabilities of the current hardware
|
||||
nvmlDevice_t device;
|
||||
CudaComputeCapabilities capabilities{0, 0};
|
||||
if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
|
||||
SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
|
||||
if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) {
|
||||
SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor);
|
||||
}
|
||||
}
|
||||
|
||||
return capabilities;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the number of GPU detected. If no GPU is detected, return size_t::max()
|
||||
* @return
|
||||
*/
|
||||
std::optional<size_t> GetNumDevices() {
|
||||
uint32_t numGpus = 0;
|
||||
if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
|
||||
return std::optional(numGpus);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif //TGI_TRTLLM_BACKEND_HARDWARE_H
|
|
@ -0,0 +1,146 @@
|
|||
#include <fstream>
|
||||
|
||||
#include <fmt/ranges.h>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <nvml.h>
|
||||
|
||||
#include "backend.h"
|
||||
#include "hardware.h"
|
||||
|
||||
void huggingface::tgi::backends::InitializeBackend() {
|
||||
SPDLOG_INFO("Initializing Backend...");
|
||||
nvmlInit_v2();
|
||||
initTrtLlmPlugins();
|
||||
|
||||
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
|
||||
if (numGpus.has_value()) {
|
||||
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
|
||||
} else {
|
||||
SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system");
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
|
||||
tle::ExecutorConfig execConfig(1);
|
||||
|
||||
// Retrieve the compute capabilities to enable some options at runtime
|
||||
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
|
||||
|
||||
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
|
||||
if (config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1) {
|
||||
SPDLOG_INFO("Detected single engine deployment, using leader mode");
|
||||
execConfig.setParallelConfig(tle::ParallelConfig(
|
||||
tle::CommunicationType::kMPI,
|
||||
tle::CommunicationMode::kLEADER,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
std::nullopt
|
||||
));
|
||||
} else { // Multiple engines -> using orchestrator mode (MPI involved)
|
||||
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
|
||||
execConfig.setParallelConfig(tle::ParallelConfig(
|
||||
tle::CommunicationType::kMPI,
|
||||
tle::CommunicationMode::kORCHESTRATOR,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
tle::OrchestratorConfig(true, workerPath, nullptr, true)
|
||||
));
|
||||
}
|
||||
|
||||
// Define some configuration variables
|
||||
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
|
||||
execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere());
|
||||
return execConfig;
|
||||
}
|
||||
|
||||
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
||||
uint32_t topK,
|
||||
float_t topP,
|
||||
float_t temperature,
|
||||
float_t repetition_penalty,
|
||||
float_t frequency_penalty,
|
||||
uint64_t seed) {
|
||||
return tle::SamplingConfig(
|
||||
1, // TGI only use a single beam
|
||||
topK,
|
||||
topP,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
seed,
|
||||
temperature,
|
||||
temperature,
|
||||
std::nullopt,
|
||||
repetition_penalty,
|
||||
std::nullopt,
|
||||
frequency_penalty
|
||||
);
|
||||
}
|
||||
|
||||
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
||||
const std::filesystem::path &enginesFolder,
|
||||
const std::filesystem::path &executorWorker
|
||||
) :
|
||||
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
|
||||
executor(
|
||||
enginesFolder,
|
||||
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
|
||||
GetExecutorConfig(config, executorWorker.string()
|
||||
)) {
|
||||
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
|
||||
}
|
||||
|
||||
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
|
||||
return executor.canEnqueueRequests();
|
||||
}
|
||||
|
||||
[[nodiscard("Returned number of requests needs to be consumed")]]
|
||||
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
|
||||
return executor.getNumResponsesReady();
|
||||
}
|
||||
|
||||
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
||||
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
||||
const std::vector<tle::TokenIdType> &tokens,
|
||||
const int32_t topK,
|
||||
const float_t topP,
|
||||
const float_t temperature,
|
||||
const float_t repetition_penalty,
|
||||
const float_t frequency_penalty,
|
||||
const uint64_t seed
|
||||
) {
|
||||
#ifdef NDEBUG
|
||||
SPDLOG_DEBUG(
|
||||
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
|
||||
tokens.size(),
|
||||
executor.getLatestIterationStats().back().numActiveRequests
|
||||
);
|
||||
#else
|
||||
SPDLOG_DEBUG(
|
||||
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
|
||||
fmt::join(tokens, ", "),
|
||||
executor.getLatestIterationStats().front().numActiveRequests
|
||||
);
|
||||
#endif
|
||||
|
||||
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
|
||||
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
|
||||
|
||||
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||
const auto output = tle::OutputConfig(true, false, false, true, false);
|
||||
return executor.enqueueRequest(
|
||||
tle::Request{tokens, maxNewTokens, true, sampling, output});
|
||||
}
|
||||
|
||||
[[nodiscard("Generated tokens result must be used")]]
|
||||
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
|
||||
SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
|
||||
return executor.awaitResponses(requestId);
|
||||
}
|
||||
|
||||
|
||||
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
|
||||
SPDLOG_INFO("Shutting down executor");
|
||||
executor.shutdown();
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -ex
|
||||
|
||||
TRT_VER="10.2.0.19"
|
||||
CUDA_VER="12.5"
|
||||
CUDNN_VER="9.2.1.18-1"
|
||||
NCCL_VER="2.22.3-1+cuda12.5"
|
||||
CUBLAS_VER="12.5.3.2-1"
|
||||
NVRTC_VER="12.5.82-1"
|
||||
|
||||
for i in "$@"; do
|
||||
case $i in
|
||||
--TRT_VER=?*) TRT_VER="${i#*=}";;
|
||||
--CUDA_VER=?*) CUDA_VER="${i#*=}";;
|
||||
--CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
|
||||
--NCCL_VER=?*) NCCL_VER="${i#*=}";;
|
||||
--CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
|
||||
*) ;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
NVCC_VERSION_OUTPUT=$(nvcc --version)
|
||||
if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
|
||||
echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
install_ubuntu_requirements() {
|
||||
apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates
|
||||
ARCH=$(uname -m)
|
||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
|
||||
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
|
||||
dpkg -i cuda-keyring_1.0-1_all.deb
|
||||
|
||||
apt-get update
|
||||
if [[ $(apt list --installed | grep libcudnn9) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages libcudnn9*
|
||||
fi
|
||||
if [[ $(apt list --installed | grep libnccl) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages libnccl*
|
||||
fi
|
||||
if [[ $(apt list --installed | grep libcublas) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages libcublas*
|
||||
fi
|
||||
if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev*
|
||||
fi
|
||||
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||
apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER}
|
||||
apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER}
|
||||
apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}
|
||||
# NVRTC static library doesn't exist in NGC PyTorch container.
|
||||
NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||
apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER}
|
||||
apt-get clean
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
}
|
||||
|
||||
install_centos_requirements() {
|
||||
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||
yum -y update
|
||||
yum -y install epel-release
|
||||
yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER}
|
||||
yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}
|
||||
yum clean all
|
||||
}
|
||||
|
||||
install_tensorrt() {
|
||||
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
|
||||
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
|
||||
TRT_CUDA_VERSION="12.5"
|
||||
|
||||
if [ -z "$RELEASE_URL_TRT" ];then
|
||||
ARCH=${TRT_TARGETARCH}
|
||||
if [ -z "$ARCH" ];then ARCH=$(uname -m);fi
|
||||
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
|
||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
|
||||
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
|
||||
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
|
||||
fi
|
||||
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
|
||||
tar -xf /tmp/TensorRT.tar -C /usr/local/
|
||||
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
|
||||
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
|
||||
rm -rf /tmp/TensorRT.tar
|
||||
}
|
||||
|
||||
# Install base packages depending on the base OS
|
||||
ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
|
||||
case "$ID" in
|
||||
debian)
|
||||
install_ubuntu_requirements
|
||||
install_tensorrt
|
||||
;;
|
||||
ubuntu)
|
||||
install_ubuntu_requirements
|
||||
install_tensorrt
|
||||
;;
|
||||
centos)
|
||||
install_centos_requirements
|
||||
install_tensorrt
|
||||
;;
|
||||
*)
|
||||
echo "Unable to determine OS..."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
|
@ -0,0 +1,329 @@
|
|||
use std::future::Future;
|
||||
use std::path::Path;
|
||||
use std::pin::{pin, Pin};
|
||||
use std::str::FromStr;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use cxx::UniquePtr;
|
||||
use log::{error, warn};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::{Instant, sleep};
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{instrument, Level, span};
|
||||
|
||||
use text_generation_router::{FinishReason, Token};
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
|
||||
use text_generation_router::validation::ValidationError::UnsupportedModality;
|
||||
|
||||
use crate::errors::TensorRtLlmBackendError;
|
||||
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
|
||||
|
||||
// Value used to poll the state of the generation stream
|
||||
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
|
||||
|
||||
type InferResult<T> = Result<T, InferError>;
|
||||
|
||||
pub(crate) struct Generation {
|
||||
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
||||
done: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
/// Holds the user provided input to be executed along with a channel allowing
|
||||
/// to bubble up all the generated tokens for that tokens the to end stream.
|
||||
pub struct GenerationContext {
|
||||
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
tokens: Vec<u32>,
|
||||
done: Arc<AtomicBool>,
|
||||
queued: Instant,
|
||||
start: Option<Instant>,
|
||||
}
|
||||
|
||||
impl Stream for Generation {
|
||||
type Item = usize;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let interval = POLLING_INTERVAL_US.get_or_init(|| {
|
||||
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
|
||||
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
|
||||
});
|
||||
|
||||
if !self.done.load(Ordering::Relaxed) {
|
||||
let backend = pin!(self.executor.read());
|
||||
let status = match backend.poll(ctx) {
|
||||
Poll::Ready(executor_r) => {
|
||||
let ready = executor_r.num_responses_ready();
|
||||
if ready == 0 {
|
||||
Poll::Pending
|
||||
} else {
|
||||
Poll::Ready(Some(ready))
|
||||
}
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
};
|
||||
|
||||
let waker = ctx.waker().clone();
|
||||
tokio::spawn(async {
|
||||
sleep(Duration::from_micros(*interval)).await;
|
||||
waker.wake();
|
||||
});
|
||||
|
||||
status
|
||||
} else {
|
||||
Poll::Ready(None) // end of stream
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
(1, None)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for TensorRtLlmBackendImpl {}
|
||||
unsafe impl Sync for TensorRtLlmBackendImpl {}
|
||||
|
||||
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
|
||||
pub struct TensorRtLlmBackend {
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
|
||||
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
|
||||
// the number of available tokens (read only) in the Generation stream
|
||||
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
||||
}
|
||||
|
||||
impl TensorRtLlmBackend {
|
||||
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
|
||||
tokenizer: Tokenizer,
|
||||
engine_folder: P,
|
||||
executor_worker_path: PP,
|
||||
) -> Result<Self, TensorRtLlmBackendError> {
|
||||
Ok(TensorRtLlmBackend {
|
||||
tokenizer: Arc::new(tokenizer),
|
||||
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
|
||||
engine_folder.as_ref().to_str().unwrap(),
|
||||
executor_worker_path.as_ref().to_str().unwrap(),
|
||||
))),
|
||||
})
|
||||
}
|
||||
|
||||
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
|
||||
if request.top_n_tokens > 1 {
|
||||
return Err(InferError::ValidationError(
|
||||
ValidationError::TopNTokensDisabled,
|
||||
));
|
||||
}
|
||||
|
||||
// TODO: Is it really needed? How can it be validated before?
|
||||
if request.parameters.grammar.is_some() {
|
||||
return Err(InferError::ValidationError(ValidationError::Grammar));
|
||||
}
|
||||
|
||||
match request.inputs.len() {
|
||||
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
|
||||
2.. => Err(InferError::GenerationError(
|
||||
"TensorRT-LLM backend don't support multi-chunk".into(),
|
||||
)),
|
||||
1 => match request.inputs.first().expect("Single item-chunk") {
|
||||
Chunk::Text(text) => Ok(text),
|
||||
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn generate(
|
||||
&self,
|
||||
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||
tokens: Vec<u32>,
|
||||
top_k: u32,
|
||||
top_p: f32,
|
||||
temperature: f32,
|
||||
repetition_penalty: f32,
|
||||
frequency_penalty: f32,
|
||||
seed: u64,
|
||||
) {
|
||||
let tokenizer = Arc::clone(&self.tokenizer);
|
||||
let executor = Arc::clone(&self.backend);
|
||||
|
||||
// Let's push this in async context
|
||||
tokio::spawn(async move {
|
||||
// Define the generation state
|
||||
let mut generation = Generation {
|
||||
executor: executor.clone(),
|
||||
done: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
|
||||
// Define the context over the generation
|
||||
// TODO(asap): Do we really need so many shared-ownership?
|
||||
let ctx = Box::new(GenerationContext {
|
||||
sender: sender.clone(),
|
||||
tokenizer,
|
||||
tokens: vec![],
|
||||
done: Arc::clone(&generation.done),
|
||||
start: None,
|
||||
queued: Instant::now(),
|
||||
});
|
||||
|
||||
// We are leaking the context on-purpose to avoid the box being dropped while there are
|
||||
// still computation ongoing
|
||||
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
|
||||
let ctx_ = Box::leak(ctx);
|
||||
|
||||
// Submit the request to the batcher
|
||||
let request_id = span!(Level::DEBUG, "submit")
|
||||
.in_scope(|| async {
|
||||
let mut handle = executor.write().await;
|
||||
let request_id = handle.pin_mut().submit(
|
||||
&tokens,
|
||||
top_k as i32,
|
||||
top_p,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
seed,
|
||||
);
|
||||
|
||||
request_id
|
||||
})
|
||||
.await;
|
||||
|
||||
while let Some(_) = generation.next().await {
|
||||
let mut executor_w = executor.write().await;
|
||||
let executor = executor_w.pin_mut();
|
||||
|
||||
span!(Level::DEBUG, "decode")
|
||||
.in_scope(|| async {
|
||||
unsafe {
|
||||
executor.stream_tokens(
|
||||
request_id,
|
||||
ctx_,
|
||||
|ctx: *mut GenerationContext, step: GenerationStep| {
|
||||
let inner_ctx = &mut *ctx;
|
||||
|
||||
// Update the timestamp at which the request started effectively
|
||||
// Can be a bit off, would need to be before the callback, let's see
|
||||
inner_ctx.start.get_or_insert(Instant::now());
|
||||
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
|
||||
|
||||
// Ensure we are not running into errors
|
||||
let parcel = if !step.has_error {
|
||||
// Insert the latest generated token to the tracker
|
||||
inner_ctx.tokens.push(step.token_id);
|
||||
|
||||
// Decode the token
|
||||
let text = inner_ctx
|
||||
.tokenizer
|
||||
.decode(&[step.token_id], true)
|
||||
.expect("Failed to decode token");
|
||||
|
||||
let special = inner_ctx
|
||||
.tokenizer
|
||||
.get_added_vocabulary()
|
||||
.is_special_token(&text);
|
||||
|
||||
// Create the structure holding the token
|
||||
let token = Token {
|
||||
id: step.token_id,
|
||||
text,
|
||||
logprob: step.log_prob,
|
||||
special,
|
||||
};
|
||||
|
||||
if step.is_final {
|
||||
let generated_text = inner_ctx
|
||||
.tokenizer
|
||||
.decode(&inner_ctx.tokens, true)
|
||||
.expect("Failed to decode generated_tokens");
|
||||
|
||||
Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens: vec![],
|
||||
generated_text: GeneratedText {
|
||||
text: generated_text,
|
||||
generated_tokens: inner_ctx.tokens.len() as u32,
|
||||
finish_reason: FinishReason::EndOfSequenceToken,
|
||||
seed: None,
|
||||
},
|
||||
start: inner_ctx.start.unwrap_or(Instant::now()),
|
||||
queued: inner_ctx.queued,
|
||||
})
|
||||
} else {
|
||||
Ok(InferStreamResponse::Intermediate {
|
||||
token,
|
||||
top_tokens: vec![],
|
||||
})
|
||||
}
|
||||
} else {
|
||||
error!("Error caught while decoding: {}", &step.error_msg);
|
||||
Err(InferError::GenerationError(step.error_msg))
|
||||
};
|
||||
|
||||
// Send the parcel to the client
|
||||
inner_ctx
|
||||
.sender
|
||||
.send(parcel)
|
||||
.expect("Failed to sent msg through the channel");
|
||||
},
|
||||
);
|
||||
}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
// "Properly" free the shared context...
|
||||
// TODO: clean that piece of sh** asap
|
||||
unsafe {
|
||||
let _ = Box::from_raw(ctx_);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for TensorRtLlmBackend {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
|
||||
// Let's add a few more validation
|
||||
let input = TensorRtLlmBackend::validate(&request)?;
|
||||
|
||||
// Channel to stream the generated token as they come from the worker thread back to the transport layer
|
||||
let (sender, receiver) = unbounded_channel();
|
||||
|
||||
// Unpack parameters
|
||||
let params = &request.parameters;
|
||||
|
||||
// Preprocess the inputs to send to TRTLLM backend
|
||||
let encoding = self
|
||||
.tokenizer
|
||||
.encode(input.as_str(), true)
|
||||
.map_err(|e| InferError::GenerationError(e.to_string()))?;
|
||||
|
||||
// Generate the response
|
||||
self.generate(
|
||||
sender,
|
||||
Vec::from(encoding.get_ids()),
|
||||
params.top_k,
|
||||
params.top_p,
|
||||
params.temperature,
|
||||
params.repetition_penalty,
|
||||
params.frequency_penalty,
|
||||
params.seed,
|
||||
);
|
||||
|
||||
Ok(UnboundedReceiverStream::new(receiver))
|
||||
}
|
||||
|
||||
async fn health(&self, _current_health: bool) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
use thiserror::Error;
|
||||
|
||||
use text_generation_router::server;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TensorRtLlmBackendError {
|
||||
#[error("Tokenizer error: {0}")]
|
||||
Tokenizer(String),
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
//
|
||||
// Created by mfuntowicz on 6/30/24.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <exception>
|
||||
#include <filesystem>
|
||||
#include <limits>
|
||||
#include <iterator>
|
||||
#include <vector>
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
#include "backends/trtllm/include/ffi.h"
|
||||
|
||||
|
||||
huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
|
||||
const std::string_view &engineFolder,
|
||||
const std::string_view &executorWorker
|
||||
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
|
||||
|
||||
|
||||
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
|
||||
return TensorRtLlmBackend::IsReady();
|
||||
}
|
||||
|
||||
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
|
||||
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
|
||||
float_t frequency_penalty, uint64_t seed) {
|
||||
|
||||
// This will copy all the items from the initial slice
|
||||
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
|
||||
return TensorRtLlmBackend::Submit(
|
||||
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||
}
|
||||
|
||||
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
|
||||
const uint64_t requestId,
|
||||
huggingface::tgi::backends::GenerationContext *ctx,
|
||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
||||
huggingface::tgi::backends::GenerationStep)> callback) {
|
||||
|
||||
size_t numTokens = 0;
|
||||
for (const auto &item: Poll(requestId)) {
|
||||
GenerationStep step;
|
||||
if (!item.hasError()) {
|
||||
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
|
||||
const auto decoded = item.getResult();
|
||||
|
||||
const auto token = decoded.outputTokenIds[0][0];
|
||||
const auto isFinal = decoded.isFinal;
|
||||
const auto logProb = decoded.logProbs.value()[0][0];
|
||||
|
||||
++numTokens;
|
||||
|
||||
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
|
||||
step = huggingface::tgi::backends::GenerationStep{
|
||||
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
|
||||
};
|
||||
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
|
||||
} else {
|
||||
// TODO : Return rest::Result with error
|
||||
const auto what = item.getErrorMsg();
|
||||
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
|
||||
step = huggingface::tgi::backends::GenerationStep{
|
||||
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
|
||||
};
|
||||
}
|
||||
|
||||
callback(std::move(ctx), std::move(step));
|
||||
}
|
||||
|
||||
return numTokens;
|
||||
}
|
||||
|
||||
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
|
||||
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
|
||||
// Unconditionally call this to initialize and discover TRTLLM plugins
|
||||
InitializeBackend();
|
||||
|
||||
const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
|
||||
const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
|
||||
return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
pub use backend::{GenerationContext, TensorRtLlmBackend};
|
||||
|
||||
mod backend;
|
||||
pub mod errors;
|
||||
|
||||
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
|
||||
mod ffi {
|
||||
|
||||
/// Struct used as shared type between rust and C++ to represent the result
|
||||
/// of a single decoding iteration
|
||||
pub struct GenerationStep {
|
||||
token_id: u32,
|
||||
log_prob: f32,
|
||||
is_final: bool,
|
||||
has_error: bool,
|
||||
error_msg: String,
|
||||
}
|
||||
|
||||
extern "Rust" {
|
||||
type GenerationContext;
|
||||
}
|
||||
|
||||
unsafe extern "C++" {
|
||||
include!("backends/trtllm/src/ffi.cpp");
|
||||
|
||||
/// Represent an instance of the underlying TensorRT-LLM backend
|
||||
type TensorRtLlmBackendImpl;
|
||||
|
||||
/// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `engine_folder`: Path to the folder containing all the TRTLLM engines
|
||||
/// * `executor_worker`: Path to the TRTLLM executor worker
|
||||
///
|
||||
/// returns: <unknown>
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
/// ```
|
||||
#[rust_name = "create_tensorrt_llm_backend"]
|
||||
fn CreateTensorRtLlmBackend(
|
||||
engine_folder: &str,
|
||||
executor_worker: &str,
|
||||
) -> UniquePtr<TensorRtLlmBackendImpl>;
|
||||
|
||||
// #[rust_name = "is_ready"]
|
||||
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
|
||||
|
||||
#[rust_name = "num_responses_ready"]
|
||||
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
|
||||
|
||||
#[rust_name = "submit"]
|
||||
fn Submit(
|
||||
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||
tokens: &[u32],
|
||||
top_k: i32,
|
||||
top_p: f32,
|
||||
temperature: f32,
|
||||
repetition_penalty: f32,
|
||||
frequency_penalty: f32,
|
||||
seed: u64,
|
||||
) -> u64;
|
||||
|
||||
#[rust_name = "stream_tokens"]
|
||||
unsafe fn StreamTokens(
|
||||
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||
request_id: u64,
|
||||
ctx: *mut GenerationContext,
|
||||
cb: unsafe fn(*mut GenerationContext, GenerationStep),
|
||||
) -> usize;
|
||||
|
||||
// #[rust_name = "shutdown"]
|
||||
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,165 @@
|
|||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use clap::Parser;
|
||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||
|
||||
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
|
||||
use text_generation_backends_trtllm::TensorRtLlmBackend;
|
||||
use text_generation_router::server;
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
max_best_of: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_tokens: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(long, env, required = true)]
|
||||
tokenizer_name: String,
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(long, env)]
|
||||
model_id: String,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||
otlp_service_name: String,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
messages_api_enabled: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
#[clap(long, env)]
|
||||
auth_token: Option<String>,
|
||||
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
|
||||
executor_worker: PathBuf,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
hostname,
|
||||
port,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
model_id,
|
||||
validation_workers,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
otlp_service_name,
|
||||
cors_allow_origin,
|
||||
messages_api_enabled,
|
||||
max_client_batch_size,
|
||||
auth_token,
|
||||
executor_worker,
|
||||
} = args;
|
||||
|
||||
// Launch Tokio runtime
|
||||
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||
|
||||
// Validate args
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
}
|
||||
|
||||
if !executor_worker.exists() {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
|
||||
"`executor_work` specified path doesn't exists: {}",
|
||||
executor_worker.display()
|
||||
)));
|
||||
}
|
||||
|
||||
// Run server
|
||||
let tokenizer = Tokenizer::from_pretrained(
|
||||
tokenizer_name.clone(),
|
||||
Some(FromPretrainedParameters {
|
||||
revision: revision.clone().unwrap_or(String::from("main")),
|
||||
user_agent: HashMap::new(),
|
||||
auth_token,
|
||||
}),
|
||||
)
|
||||
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
|
||||
|
||||
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
|
||||
server::run(
|
||||
backend,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
validation_workers,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
hostname,
|
||||
port,
|
||||
cors_allow_origin,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
messages_api_enabled,
|
||||
true,
|
||||
max_client_batch_size,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
//
|
||||
// Created by mfuntowicz on 7/2/24.
|
||||
//
|
||||
#include <catch2/catch_all.hpp>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include "../include/backend.h"
|
||||
|
||||
TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") {
|
||||
const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/");
|
||||
const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker");
|
||||
|
||||
spdlog::info("Loading config from: {}", absolute(engines).string());
|
||||
huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor);
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
[package]
|
||||
name = "text-generation-router-v3"
|
||||
description = "Text Generation Webserver"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "text-generation-router"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.16"
|
||||
text-generation-router = { path = "../../router" }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
grpc-metadata = { path = "../grpc-metadata" }
|
||||
futures = "0.3.28"
|
||||
hf-hub = { workspace = true }
|
||||
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||
metrics = "0.21.1"
|
||||
metrics-exporter-prometheus = { version = "0.12.1", features = [] }
|
||||
nohash-hasher = "0.2.0"
|
||||
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.13.0"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.20", features = [] }
|
||||
serde = "1.0.188"
|
||||
serde_json = "1.0.107"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true}
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.14"
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.21.0"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||
minijinja = { version = "2.0.2" }
|
||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = { workspace = true }
|
||||
prost = "^0.12"
|
||||
tonic = "^0.10"
|
||||
tower = "^0.4"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.10.1"
|
||||
prost-build = "0.12.1"
|
||||
|
||||
[features]
|
||||
default = ["ngrok"]
|
||||
ngrok = ["text-generation-router/ngrok"]
|
||||
google = ["text-generation-router/google"]
|
||||
kserve = ["text-generation-router/kserve"]
|
|
@ -0,0 +1,19 @@
|
|||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/");
|
||||
|
||||
fs::create_dir_all("src/client/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/client/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,490 @@
|
|||
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
||||
/// Batching and inference logic
|
||||
use crate::queue::{Entry, Queue};
|
||||
use async_trait::async_trait;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub struct BackendV3 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
/// Client clone, used for health checks to skip the queue
|
||||
client: ShardedClient,
|
||||
}
|
||||
|
||||
impl BackendV3 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
16,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client.clone(),
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
queue.clone(),
|
||||
batching_task_notifier.clone(),
|
||||
));
|
||||
|
||||
Self {
|
||||
queue,
|
||||
batching_task_notifier,
|
||||
client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for BackendV3 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
request,
|
||||
response_tx,
|
||||
span: Span::current(),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
block_allocation: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
// to be batched
|
||||
self.batching_task_notifier.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok(UnboundedReceiverStream::new(response_rx))
|
||||
}
|
||||
|
||||
async fn health(&self, current_health: bool) -> bool {
|
||||
if current_health {
|
||||
// Generation is healthy, we only check that the shards can allocate on device
|
||||
self.client.device_health().await
|
||||
} else {
|
||||
self.client.model_health().await
|
||||
}
|
||||
.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
queue: Queue,
|
||||
notifier: Arc<Notify>,
|
||||
) {
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
notifier.notified().await;
|
||||
|
||||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue
|
||||
.next_batch(
|
||||
None,
|
||||
max_batch_size,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
// all requests have met their stopping criteria)
|
||||
while let Some(batch) = cached_batch {
|
||||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
|
||||
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
.await
|
||||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
|
||||
} else {
|
||||
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||
// Add relationships
|
||||
span.follows_from(&entry_waiting_span);
|
||||
entry_waiting_span.follows_from(&span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_waiting_span);
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
entries.extend(new_entries);
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_size = entries.len();
|
||||
let next_batch_span =
|
||||
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = decode(&mut client, batches, &mut entries)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size", 0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<CachedBatch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
for id in batch_ids {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
#[instrument(skip_all)]
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
next_batch: Option<CachedBatch>,
|
||||
entries: &IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let mut batch = next_batch?;
|
||||
|
||||
// No need to filter
|
||||
if batch.size as usize == entries.len() {
|
||||
return Some(batch);
|
||||
}
|
||||
|
||||
let id = batch.id;
|
||||
|
||||
// Retain only requests that are still in entries
|
||||
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||
|
||||
if batch.request_ids.is_empty() {
|
||||
// All requests have been filtered out
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.clear_cache(Some(id)).await.unwrap();
|
||||
None
|
||||
} else {
|
||||
// Filter Python shard cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
/// and filter entries
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
generations.into_iter().for_each(|generation| {
|
||||
let id = generation.request_id;
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let entry = entries
|
||||
.get(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||
// Send generation responses back to the infer task
|
||||
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
err
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Send responses through the `entry` response channel
|
||||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let mut stopped = false;
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Create Token objects
|
||||
// We do that here instead of in the Python code as Rust for loops are faster
|
||||
let prefill_tokens = prefill_tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(prefill_tokens.logprobs)
|
||||
.zip(prefill_tokens.texts)
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(tokens_.logprobs)
|
||||
.zip(tokens_.texts)
|
||||
.zip(tokens_.is_special)
|
||||
.enumerate()
|
||||
.peekable();
|
||||
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||
let token = Token {
|
||||
id,
|
||||
text,
|
||||
logprob,
|
||||
special,
|
||||
};
|
||||
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||
top_tokens_
|
||||
.ids
|
||||
.iter()
|
||||
.zip(top_tokens_.logprobs.iter())
|
||||
.zip(top_tokens_.texts.iter())
|
||||
.zip(top_tokens_.is_special.iter())
|
||||
.map(|(((&id, &logprob), text), &special)| Token {
|
||||
id,
|
||||
text: text.to_string(),
|
||||
logprob,
|
||||
special,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
match (&generation.generated_text, iterator.peek()) {
|
||||
(Some(generated_text), None) => {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: GeneratedText::from(generated_text.clone()),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
}
|
||||
_ => {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stopped)
|
||||
}
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||
fn from(value: crate::client::GeneratedText) -> Self {
|
||||
let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v3_finish_reason {
|
||||
crate::client::FinishReason::Length => FinishReason::Length,
|
||||
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
};
|
||||
|
||||
Self {
|
||||
text: value.text,
|
||||
generated_tokens: value.generated_tokens,
|
||||
finish_reason,
|
||||
seed: value.seed,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
use crate::v3::{pb, Chunk};
|
||||
use crate::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
/// Single shard Client
|
||||
use crate::client::{pb, Chunk};
|
||||
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::Engine;
|
||||
use grpc_metadata::InjectTelemetryContext;
|
||||
|
@ -19,6 +19,7 @@ pub struct Client {
|
|||
|
||||
impl Client {
|
||||
/// Returns a client connected to the given url
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let channel = Channel::builder(uri).connect().await?;
|
||||
|
|
@ -1,15 +1,23 @@
|
|||
//! Text Generation gRPC client library
|
||||
|
||||
pub mod v2;
|
||||
pub mod v3;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use thiserror::Error;
|
||||
use tonic::transport;
|
||||
use tonic::Status;
|
||||
|
||||
pub use v3::{Chunk, Image, Input, InputChunk};
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v3::{
|
||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
||||
StoppingCriteriaParameters,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Health {
|
||||
|
@ -63,29 +71,6 @@ impl From<Chunk> for InputChunk {
|
|||
}
|
||||
}
|
||||
|
||||
/// Convert input chunks to a stringly-typed input for backwards
|
||||
/// compat for backends that haven't implemented chunked inputs.
|
||||
pub trait ChunksToString {
|
||||
/// Convert chunks to string.
|
||||
fn chunks_to_string(&self) -> String;
|
||||
}
|
||||
|
||||
impl ChunksToString for Vec<InputChunk> {
|
||||
fn chunks_to_string(&self) -> String {
|
||||
let mut output = String::new();
|
||||
self.iter().for_each(|c| match &c.chunk {
|
||||
Some(Chunk::Text(text)) => output.push_str(text),
|
||||
Some(Chunk::Image(Image { data, mimetype })) => {
|
||||
let encoded = STANDARD.encode(data);
|
||||
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
|
||||
}
|
||||
// We don't create empty chunks, so this should be unreachable.
|
||||
None => unreachable!("Chunks should never be empty"),
|
||||
});
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
|
@ -1,17 +1,17 @@
|
|||
use crate::client::{ClientError, Result};
|
||||
/// Multi shard Client
|
||||
use crate::{v3, Health, ShardInfo};
|
||||
use crate::{ClientError, Result};
|
||||
use crate::client::{Health, ShardInfo};
|
||||
|
||||
use crate::v3::{Chunk, InfoResponse, Input};
|
||||
use crate::client::client::{DecodeTimings, PrefillTimings};
|
||||
use crate::client::{
|
||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use crate::client::{Chunk, InfoResponse, Input};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
use v3::client::{DecodeTimings, PrefillTimings};
|
||||
use v3::{
|
||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Text Generation Inference gRPC multi client
|
||||
|
@ -35,6 +35,7 @@ impl ShardedClient {
|
|||
}
|
||||
|
||||
/// Returns a client connected to the given uri
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let master_client = Client::connect(uri).await?;
|
||||
Self::from_master_client(master_client).await
|
|
@ -0,0 +1,141 @@
|
|||
mod backend;
|
||||
mod block_allocator;
|
||||
mod client;
|
||||
mod queue;
|
||||
|
||||
use crate::client::{ClientError, ShardedClient};
|
||||
pub(crate) use backend::BackendV3;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct BackendInfo {
|
||||
/// Mandatory
|
||||
#[schema(example = "cuda")]
|
||||
pub model_device_type: String,
|
||||
#[schema(example = "torch.float16")]
|
||||
pub model_dtype: String,
|
||||
|
||||
/// Backend parameters
|
||||
#[schema(example = "1")]
|
||||
pub speculate: usize,
|
||||
#[schema(example = "1.2")]
|
||||
pub waiting_served_ratio: f32,
|
||||
#[schema(example = "32000")]
|
||||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
}
|
||||
|
||||
pub async fn connect_backend(
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
master_shard_uds_path: String,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
||||
// Helper function
|
||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||
match max_supported_batch_total_tokens {
|
||||
// Older models do not support automatic max-batch-total-tokens
|
||||
None => {
|
||||
let max_batch_total_tokens = max_batch_total_tokens
|
||||
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||
Ok(max_batch_total_tokens)
|
||||
}
|
||||
// Flash attention models return their max supported total tokens
|
||||
Some(max_supported_batch_total_tokens) => {
|
||||
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||
if max_batch_total_tokens.is_some() {
|
||||
tracing::warn!(
|
||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||
Attention models."
|
||||
);
|
||||
tracing::warn!(
|
||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||
);
|
||||
}
|
||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||
return Err(V3Error::NotEnoughMemory(max_total_tokens));
|
||||
}
|
||||
|
||||
Ok(max_supported_batch_total_tokens)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
.map_err(V3Error::Connection)?;
|
||||
|
||||
// server is running on v3
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.map_err(V3Error::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||
sharded_client
|
||||
.warmup(
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(V3Error::Warmup)?,
|
||||
)?;
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
|
||||
let backend_info = BackendInfo {
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
model_device_type: shard_info.device_type.clone(),
|
||||
model_dtype: shard_info.dtype.clone(),
|
||||
speculate: shard_info.speculate as usize,
|
||||
};
|
||||
|
||||
let backend = BackendV3::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
);
|
||||
|
||||
tracing::info!("Using backend V3");
|
||||
|
||||
Ok((backend, backend_info))
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum V3Error {
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
Connection(ClientError),
|
||||
#[error("Unable to get the Python model shards info: {0}")]
|
||||
Info(ClientError),
|
||||
#[error("Unable to warmup the Python model shards: {0}")]
|
||||
Warmup(ClientError),
|
||||
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||
NotEnoughMemory(usize),
|
||||
}
|
|
@ -0,0 +1,181 @@
|
|||
use clap::Parser;
|
||||
use text_generation_router::server;
|
||||
use text_generation_router_v3::{connect_backend, V3Error};
|
||||
use thiserror::Error;
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
max_best_of: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_tokens: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "1.2", long, env)]
|
||||
waiting_served_ratio: f32,
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||
master_shard_uds_path: String,
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||
otlp_service_name: String,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env)]
|
||||
ngrok: bool,
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_edge: Option<String>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
messages_api_enabled: bool,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
disable_grammar_support: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), RouterError> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
hostname,
|
||||
port,
|
||||
master_shard_uds_path,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
validation_workers,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
otlp_service_name,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
} = args;
|
||||
|
||||
// Launch Tokio runtime
|
||||
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||
|
||||
// Validate args
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
}
|
||||
|
||||
let (backend, _backend_info) = connect_backend(
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
master_shard_uds_path,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Run server
|
||||
server::run(
|
||||
backend,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
validation_workers,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
hostname,
|
||||
port,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("Backend failed: {0}")]
|
||||
Backend(#[from] V3Error),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
}
|
|
@ -1,17 +1,17 @@
|
|||
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
|
||||
use crate::infer::InferError;
|
||||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::{
|
||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
use crate::block_allocator::{BlockAllocation, BlockAllocator};
|
||||
use crate::client;
|
||||
use crate::client::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::{max, min};
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_client::v3::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
use text_generation_router::validation::{
|
||||
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
|
||||
ValidStoppingParameters,
|
||||
};
|
||||
use text_generation_client::ChunksToString;
|
||||
use text_generation_client::Input;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
@ -337,8 +337,22 @@ impl State {
|
|||
batch_requests.push(Request {
|
||||
id,
|
||||
prefill_logprobs: entry.request.decoder_input_details,
|
||||
input_chunks: Some(Input {
|
||||
chunks: entry.request.inputs.clone(),
|
||||
input_chunks: Some(client::Input {
|
||||
chunks: entry
|
||||
.request
|
||||
.inputs
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|c| client::InputChunk {
|
||||
chunk: Some(match c {
|
||||
Chunk::Text(text) => client::Chunk::Text(text),
|
||||
Chunk::Image(image) => client::Chunk::Image(client::Image {
|
||||
data: image.data,
|
||||
mimetype: image.mimetype,
|
||||
}),
|
||||
}),
|
||||
})
|
||||
.collect(),
|
||||
}),
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
truncate: entry.request.truncate,
|
|
@ -21,7 +21,7 @@ float-ord = "0.3.2"
|
|||
serde = {version = "1.0.188", features = ["derive"]}
|
||||
serde_json = "1.0"
|
||||
tabled = "0.14.0"
|
||||
text-generation-client = { path = "../router/client" }
|
||||
text-generation-client = { path = "../backends/client" }
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
||||
|
|
|
@ -7,18 +7,11 @@ edition.workspace = true
|
|||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "text-generation-router"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.16"
|
||||
text-generation-client = { path = "client" }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
futures = "0.3.28"
|
||||
hf-hub = { workspace = true }
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/");
|
||||
|
||||
fs::create_dir_all("src/v2/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/v2/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/generate.proto"], &["../../proto"])
|
||||
.map_err(|e| match e.kind(){
|
||||
std::io::ErrorKind::NotFound => {panic!("`protoc` not found, install libprotoc")},
|
||||
std::io::ErrorKind::Other => {panic!("`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases")},
|
||||
e => {e}
|
||||
}).unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
fs::create_dir_all("src/v3/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/v3/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
*
|
|
@ -0,0 +1,647 @@
|
|||
// This file is @generated by prost-build.
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct HealthRequest {}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct HealthResponse {}
|
||||
/// / Empty request
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct InfoRequest {}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct InfoResponse {
|
||||
#[prost(bool, tag = "1")]
|
||||
pub requires_padding: bool,
|
||||
#[prost(string, tag = "2")]
|
||||
pub dtype: ::prost::alloc::string::String,
|
||||
#[prost(string, tag = "3")]
|
||||
pub device_type: ::prost::alloc::string::String,
|
||||
#[prost(uint32, optional, tag = "4")]
|
||||
pub window_size: ::core::option::Option<u32>,
|
||||
#[prost(uint32, tag = "5")]
|
||||
pub speculate: u32,
|
||||
}
|
||||
/// / Empty request
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct ServiceDiscoveryRequest {}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct ServiceDiscoveryResponse {
|
||||
/// / Other shards urls
|
||||
#[prost(string, repeated, tag = "1")]
|
||||
pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct ClearCacheRequest {
|
||||
/// / Optional batch id
|
||||
#[prost(uint64, optional, tag = "1")]
|
||||
pub id: ::core::option::Option<u64>,
|
||||
}
|
||||
/// / Empty response
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct ClearCacheResponse {}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct NextTokenChooserParameters {
|
||||
/// / exponential scaling output probability distribution
|
||||
#[prost(float, tag = "1")]
|
||||
pub temperature: f32,
|
||||
/// / restricting to the k highest probability elements
|
||||
#[prost(uint32, tag = "2")]
|
||||
pub top_k: u32,
|
||||
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
#[prost(float, tag = "3")]
|
||||
pub top_p: f32,
|
||||
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
#[prost(float, tag = "4")]
|
||||
pub typical_p: f32,
|
||||
/// / apply sampling on the logits
|
||||
#[prost(bool, tag = "5")]
|
||||
pub do_sample: bool,
|
||||
/// / random seed for sampling
|
||||
#[prost(uint64, tag = "6")]
|
||||
pub seed: u64,
|
||||
/// / repetition penalty
|
||||
#[prost(float, tag = "7")]
|
||||
pub repetition_penalty: f32,
|
||||
/// / frequency penalty
|
||||
#[prost(float, tag = "9")]
|
||||
pub frequency_penalty: f32,
|
||||
/// / token watermarking using "A Watermark for Large Language Models"
|
||||
#[prost(bool, tag = "8")]
|
||||
pub watermark: bool,
|
||||
/// / grammar (applied if not empty)
|
||||
#[prost(string, tag = "10")]
|
||||
pub grammar: ::prost::alloc::string::String,
|
||||
/// / grammar type
|
||||
#[prost(enumeration = "GrammarType", tag = "11")]
|
||||
pub grammar_type: i32,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct StoppingCriteriaParameters {
|
||||
/// / Maximum number of generated tokens
|
||||
#[prost(uint32, tag = "1")]
|
||||
pub max_new_tokens: u32,
|
||||
/// / Optional stopping sequences
|
||||
#[prost(string, repeated, tag = "2")]
|
||||
pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||
/// / Ignore end of sequence token
|
||||
/// / used for benchmarking
|
||||
#[prost(bool, tag = "3")]
|
||||
pub ignore_eos_token: bool,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Request {
|
||||
/// / Request ID
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub id: u64,
|
||||
/// / The generation context
|
||||
#[prost(string, tag = "2")]
|
||||
pub inputs: ::prost::alloc::string::String,
|
||||
/// / Context truncation
|
||||
#[prost(uint32, tag = "3")]
|
||||
pub truncate: u32,
|
||||
/// / Next Token Chooser Parameters
|
||||
#[prost(message, optional, tag = "4")]
|
||||
pub parameters: ::core::option::Option<NextTokenChooserParameters>,
|
||||
/// / Stopping Criteria Parameters
|
||||
#[prost(message, optional, tag = "5")]
|
||||
pub stopping_parameters: ::core::option::Option<StoppingCriteriaParameters>,
|
||||
/// / Return prefill logprobs
|
||||
#[prost(bool, tag = "6")]
|
||||
pub prefill_logprobs: bool,
|
||||
/// / Return most likely n tokens
|
||||
#[prost(uint32, tag = "7")]
|
||||
pub top_n_tokens: u32,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Batch {
|
||||
/// / Batch ID
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub id: u64,
|
||||
/// / Individual requests
|
||||
#[prost(message, repeated, tag = "2")]
|
||||
pub requests: ::prost::alloc::vec::Vec<Request>,
|
||||
/// / Batch size (==len(requests))
|
||||
#[prost(uint32, tag = "3")]
|
||||
pub size: u32,
|
||||
/// / Maximum number of tokens this batch will grow to
|
||||
#[prost(uint32, tag = "4")]
|
||||
pub max_tokens: u32,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct CachedBatch {
|
||||
/// / Batch ID
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub id: u64,
|
||||
/// / Individual requests ids
|
||||
#[prost(uint64, repeated, tag = "2")]
|
||||
pub request_ids: ::prost::alloc::vec::Vec<u64>,
|
||||
/// / Batch size (==len(requests))
|
||||
#[prost(uint32, tag = "3")]
|
||||
pub size: u32,
|
||||
/// / Maximum number of tokens this batch will grow to
|
||||
#[prost(uint32, tag = "4")]
|
||||
pub max_tokens: u32,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct GeneratedText {
|
||||
/// / Output
|
||||
#[prost(string, tag = "1")]
|
||||
pub text: ::prost::alloc::string::String,
|
||||
/// / Number of generated tokens
|
||||
#[prost(uint32, tag = "2")]
|
||||
pub generated_tokens: u32,
|
||||
/// / Finish reason
|
||||
#[prost(enumeration = "FinishReason", tag = "3")]
|
||||
pub finish_reason: i32,
|
||||
/// / Seed
|
||||
#[prost(uint64, optional, tag = "4")]
|
||||
pub seed: ::core::option::Option<u64>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Tokens {
|
||||
/// / Token IDs
|
||||
#[prost(uint32, repeated, tag = "1")]
|
||||
pub ids: ::prost::alloc::vec::Vec<u32>,
|
||||
/// / Logprobs
|
||||
#[prost(float, repeated, tag = "2")]
|
||||
pub logprobs: ::prost::alloc::vec::Vec<f32>,
|
||||
/// / tokens
|
||||
#[prost(string, repeated, tag = "3")]
|
||||
pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||
/// / special
|
||||
#[prost(bool, repeated, tag = "4")]
|
||||
pub is_special: ::prost::alloc::vec::Vec<bool>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Generation {
|
||||
/// / Request ID
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub request_id: u64,
|
||||
/// / Prefill tokens (optional)
|
||||
#[prost(message, optional, tag = "2")]
|
||||
pub prefill_tokens: ::core::option::Option<Tokens>,
|
||||
#[prost(message, optional, tag = "3")]
|
||||
pub tokens: ::core::option::Option<Tokens>,
|
||||
/// / Complete generated text
|
||||
#[prost(message, optional, tag = "4")]
|
||||
pub generated_text: ::core::option::Option<GeneratedText>,
|
||||
/// / Top tokens
|
||||
#[prost(message, repeated, tag = "5")]
|
||||
pub top_tokens: ::prost::alloc::vec::Vec<Tokens>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct FilterBatchRequest {
|
||||
/// / Batch ID
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub batch_id: u64,
|
||||
/// / Requests to keep
|
||||
#[prost(uint64, repeated, tag = "2")]
|
||||
pub request_ids: ::prost::alloc::vec::Vec<u64>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct FilterBatchResponse {
|
||||
/// / Filtered Batch (cached)
|
||||
#[prost(message, optional, tag = "1")]
|
||||
pub batch: ::core::option::Option<CachedBatch>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct PrefillRequest {
|
||||
/// / Batch
|
||||
#[prost(message, optional, tag = "1")]
|
||||
pub batch: ::core::option::Option<Batch>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct PrefillResponse {
|
||||
/// / Generation
|
||||
#[prost(message, repeated, tag = "1")]
|
||||
pub generations: ::prost::alloc::vec::Vec<Generation>,
|
||||
/// / Next batch (cached)
|
||||
#[prost(message, optional, tag = "2")]
|
||||
pub batch: ::core::option::Option<CachedBatch>,
|
||||
/// / Forward elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "3")]
|
||||
pub forward_ns: u64,
|
||||
/// / Decode elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "4")]
|
||||
pub decode_ns: u64,
|
||||
/// / Total elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "5")]
|
||||
pub total_ns: u64,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct DecodeRequest {
|
||||
/// / Cached batches
|
||||
#[prost(message, repeated, tag = "1")]
|
||||
pub batches: ::prost::alloc::vec::Vec<CachedBatch>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct DecodeResponse {
|
||||
/// / Decodes
|
||||
#[prost(message, repeated, tag = "1")]
|
||||
pub generations: ::prost::alloc::vec::Vec<Generation>,
|
||||
/// / Next batch (cached)
|
||||
#[prost(message, optional, tag = "2")]
|
||||
pub batch: ::core::option::Option<CachedBatch>,
|
||||
/// / Forward elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "3")]
|
||||
pub forward_ns: u64,
|
||||
/// / Decode elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "4")]
|
||||
pub decode_ns: u64,
|
||||
/// / Total elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "5")]
|
||||
pub total_ns: u64,
|
||||
/// / Concatenate elapsed time in nanoseconds
|
||||
#[prost(uint64, optional, tag = "6")]
|
||||
pub concat_ns: ::core::option::Option<u64>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct WarmupRequest {
|
||||
/// / Batch to warmup on
|
||||
#[prost(message, optional, tag = "1")]
|
||||
pub batch: ::core::option::Option<Batch>,
|
||||
#[prost(uint32, tag = "2")]
|
||||
pub max_input_length: u32,
|
||||
#[prost(uint32, tag = "3")]
|
||||
pub max_prefill_tokens: u32,
|
||||
#[prost(uint32, tag = "4")]
|
||||
pub max_total_tokens: u32,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct WarmupResponse {
|
||||
/// / Maximum number of tokens supported by the model
|
||||
#[prost(uint32, optional, tag = "1")]
|
||||
pub max_supported_total_tokens: ::core::option::Option<u32>,
|
||||
}
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
|
||||
#[repr(i32)]
|
||||
pub enum GrammarType {
|
||||
None = 0,
|
||||
Json = 1,
|
||||
Regex = 2,
|
||||
}
|
||||
impl GrammarType {
|
||||
/// String value of the enum field names used in the ProtoBuf definition.
|
||||
///
|
||||
/// The values are not transformed in any way and thus are considered stable
|
||||
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
|
||||
pub fn as_str_name(&self) -> &'static str {
|
||||
match self {
|
||||
GrammarType::None => "GRAMMAR_TYPE_NONE",
|
||||
GrammarType::Json => "GRAMMAR_TYPE_JSON",
|
||||
GrammarType::Regex => "GRAMMAR_TYPE_REGEX",
|
||||
}
|
||||
}
|
||||
/// Creates an enum from field names used in the ProtoBuf definition.
|
||||
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
|
||||
match value {
|
||||
"GRAMMAR_TYPE_NONE" => Some(Self::None),
|
||||
"GRAMMAR_TYPE_JSON" => Some(Self::Json),
|
||||
"GRAMMAR_TYPE_REGEX" => Some(Self::Regex),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
|
||||
#[repr(i32)]
|
||||
pub enum FinishReason {
|
||||
Length = 0,
|
||||
EosToken = 1,
|
||||
StopSequence = 2,
|
||||
}
|
||||
impl FinishReason {
|
||||
/// String value of the enum field names used in the ProtoBuf definition.
|
||||
///
|
||||
/// The values are not transformed in any way and thus are considered stable
|
||||
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
|
||||
pub fn as_str_name(&self) -> &'static str {
|
||||
match self {
|
||||
FinishReason::Length => "FINISH_REASON_LENGTH",
|
||||
FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN",
|
||||
FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE",
|
||||
}
|
||||
}
|
||||
/// Creates an enum from field names used in the ProtoBuf definition.
|
||||
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
|
||||
match value {
|
||||
"FINISH_REASON_LENGTH" => Some(Self::Length),
|
||||
"FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken),
|
||||
"FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Generated client implementations.
|
||||
pub mod text_generation_service_client {
|
||||
#![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
|
||||
use tonic::codegen::*;
|
||||
use tonic::codegen::http::Uri;
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TextGenerationServiceClient<T> {
|
||||
inner: tonic::client::Grpc<T>,
|
||||
}
|
||||
impl TextGenerationServiceClient<tonic::transport::Channel> {
|
||||
/// Attempt to create a new client by connecting to a given endpoint.
|
||||
pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
|
||||
where
|
||||
D: TryInto<tonic::transport::Endpoint>,
|
||||
D::Error: Into<StdError>,
|
||||
{
|
||||
let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
|
||||
Ok(Self::new(conn))
|
||||
}
|
||||
}
|
||||
impl<T> TextGenerationServiceClient<T>
|
||||
where
|
||||
T: tonic::client::GrpcService<tonic::body::BoxBody>,
|
||||
T::Error: Into<StdError>,
|
||||
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
|
||||
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
|
||||
{
|
||||
pub fn new(inner: T) -> Self {
|
||||
let inner = tonic::client::Grpc::new(inner);
|
||||
Self { inner }
|
||||
}
|
||||
pub fn with_origin(inner: T, origin: Uri) -> Self {
|
||||
let inner = tonic::client::Grpc::with_origin(inner, origin);
|
||||
Self { inner }
|
||||
}
|
||||
pub fn with_interceptor<F>(
|
||||
inner: T,
|
||||
interceptor: F,
|
||||
) -> TextGenerationServiceClient<InterceptedService<T, F>>
|
||||
where
|
||||
F: tonic::service::Interceptor,
|
||||
T::ResponseBody: Default,
|
||||
T: tonic::codegen::Service<
|
||||
http::Request<tonic::body::BoxBody>,
|
||||
Response = http::Response<
|
||||
<T as tonic::client::GrpcService<tonic::body::BoxBody>>::ResponseBody,
|
||||
>,
|
||||
>,
|
||||
<T as tonic::codegen::Service<
|
||||
http::Request<tonic::body::BoxBody>,
|
||||
>>::Error: Into<StdError> + Send + Sync,
|
||||
{
|
||||
TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor))
|
||||
}
|
||||
/// Compress requests with the given encoding.
|
||||
///
|
||||
/// This requires the server to support it otherwise it might respond with an
|
||||
/// error.
|
||||
#[must_use]
|
||||
pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
|
||||
self.inner = self.inner.send_compressed(encoding);
|
||||
self
|
||||
}
|
||||
/// Enable decompressing responses.
|
||||
#[must_use]
|
||||
pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
|
||||
self.inner = self.inner.accept_compressed(encoding);
|
||||
self
|
||||
}
|
||||
/// Limits the maximum size of a decoded message.
|
||||
///
|
||||
/// Default: `4MB`
|
||||
#[must_use]
|
||||
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
|
||||
self.inner = self.inner.max_decoding_message_size(limit);
|
||||
self
|
||||
}
|
||||
/// Limits the maximum size of an encoded message.
|
||||
///
|
||||
/// Default: `usize::MAX`
|
||||
#[must_use]
|
||||
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
|
||||
self.inner = self.inner.max_encoding_message_size(limit);
|
||||
self
|
||||
}
|
||||
/// / Model Info
|
||||
pub async fn info(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::InfoRequest>,
|
||||
) -> std::result::Result<tonic::Response<super::InfoResponse>, tonic::Status> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v2.TextGenerationService/Info",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(GrpcMethod::new("generate.v2.TextGenerationService", "Info"));
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Service discovery
|
||||
pub async fn service_discovery(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::ServiceDiscoveryRequest>,
|
||||
) -> std::result::Result<
|
||||
tonic::Response<super::ServiceDiscoveryResponse>,
|
||||
tonic::Status,
|
||||
> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v2.TextGenerationService/ServiceDiscovery",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(
|
||||
GrpcMethod::new(
|
||||
"generate.v2.TextGenerationService",
|
||||
"ServiceDiscovery",
|
||||
),
|
||||
);
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Empties batch cache
|
||||
pub async fn clear_cache(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::ClearCacheRequest>,
|
||||
) -> std::result::Result<
|
||||
tonic::Response<super::ClearCacheResponse>,
|
||||
tonic::Status,
|
||||
> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v2.TextGenerationService/ClearCache",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(
|
||||
GrpcMethod::new("generate.v2.TextGenerationService", "ClearCache"),
|
||||
);
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Remove requests from a cached batch
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::FilterBatchRequest>,
|
||||
) -> std::result::Result<
|
||||
tonic::Response<super::FilterBatchResponse>,
|
||||
tonic::Status,
|
||||
> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v2.TextGenerationService/FilterBatch",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(
|
||||
GrpcMethod::new("generate.v2.TextGenerationService", "FilterBatch"),
|
||||
);
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Warmup the model and compute max cache size
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::WarmupRequest>,
|
||||
) -> std::result::Result<tonic::Response<super::WarmupResponse>, tonic::Status> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v2.TextGenerationService/Warmup",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(GrpcMethod::new("generate.v2.TextGenerationService", "Warmup"));
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Prefill batch and decode first token
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::PrefillRequest>,
|
||||
) -> std::result::Result<
|
||||
tonic::Response<super::PrefillResponse>,
|
||||
tonic::Status,
|
||||
> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v2.TextGenerationService/Prefill",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(GrpcMethod::new("generate.v2.TextGenerationService", "Prefill"));
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Decode token for a list of prefilled batches
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::DecodeRequest>,
|
||||
) -> std::result::Result<tonic::Response<super::DecodeResponse>, tonic::Status> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v2.TextGenerationService/Decode",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(GrpcMethod::new("generate.v2.TextGenerationService", "Decode"));
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Health check
|
||||
pub async fn health(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::HealthRequest>,
|
||||
) -> std::result::Result<tonic::Response<super::HealthResponse>, tonic::Status> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v2.TextGenerationService/Health",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(GrpcMethod::new("generate.v2.TextGenerationService", "Health"));
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
// This file is @generated by prost-build.
|
||||
pub mod generate {
|
||||
pub mod v2 {
|
||||
include!("generate.v2.rs");
|
||||
}
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use client::Client;
|
||||
pub use pb::generate::v3::{
|
||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
||||
StoppingCriteriaParameters, Tokens,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
|
@ -1 +0,0 @@
|
|||
*
|
|
@ -0,0 +1,697 @@
|
|||
// This file is @generated by prost-build.
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct HealthRequest {}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct HealthResponse {}
|
||||
/// / Empty request
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct InfoRequest {}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct InfoResponse {
|
||||
#[prost(bool, tag = "1")]
|
||||
pub requires_padding: bool,
|
||||
#[prost(string, tag = "2")]
|
||||
pub dtype: ::prost::alloc::string::String,
|
||||
#[prost(string, tag = "3")]
|
||||
pub device_type: ::prost::alloc::string::String,
|
||||
#[prost(uint32, optional, tag = "4")]
|
||||
pub window_size: ::core::option::Option<u32>,
|
||||
#[prost(uint32, tag = "5")]
|
||||
pub speculate: u32,
|
||||
}
|
||||
/// / Empty request
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct ServiceDiscoveryRequest {}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct ServiceDiscoveryResponse {
|
||||
/// / Other shards urls
|
||||
#[prost(string, repeated, tag = "1")]
|
||||
pub urls: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct ClearCacheRequest {
|
||||
/// / Optional batch id
|
||||
#[prost(uint64, optional, tag = "1")]
|
||||
pub id: ::core::option::Option<u64>,
|
||||
}
|
||||
/// / Empty response
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct ClearCacheResponse {}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Image {
|
||||
/// / Binary image data.
|
||||
#[prost(bytes = "vec", tag = "1")]
|
||||
pub data: ::prost::alloc::vec::Vec<u8>,
|
||||
/// / Image MIME type.
|
||||
#[prost(string, tag = "2")]
|
||||
pub mimetype: ::prost::alloc::string::String,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct InputChunk {
|
||||
#[prost(oneof = "input_chunk::Chunk", tags = "1, 2")]
|
||||
pub chunk: ::core::option::Option<input_chunk::Chunk>,
|
||||
}
|
||||
/// Nested message and enum types in `InputChunk`.
|
||||
pub mod input_chunk {
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Oneof)]
|
||||
pub enum Chunk {
|
||||
/// / Plain text data
|
||||
#[prost(string, tag = "1")]
|
||||
Text(::prost::alloc::string::String),
|
||||
/// / Image data
|
||||
#[prost(message, tag = "2")]
|
||||
Image(super::Image),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Input {
|
||||
#[prost(message, repeated, tag = "1")]
|
||||
pub chunks: ::prost::alloc::vec::Vec<InputChunk>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct NextTokenChooserParameters {
|
||||
/// / exponential scaling output probability distribution
|
||||
#[prost(float, tag = "1")]
|
||||
pub temperature: f32,
|
||||
/// / restricting to the k highest probability elements
|
||||
#[prost(uint32, tag = "2")]
|
||||
pub top_k: u32,
|
||||
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
#[prost(float, tag = "3")]
|
||||
pub top_p: f32,
|
||||
/// / restricting to top tokens summing to prob_cut_off <= prob_cut_off
|
||||
#[prost(float, tag = "4")]
|
||||
pub typical_p: f32,
|
||||
/// / apply sampling on the logits
|
||||
#[prost(bool, tag = "5")]
|
||||
pub do_sample: bool,
|
||||
/// / random seed for sampling
|
||||
#[prost(uint64, tag = "6")]
|
||||
pub seed: u64,
|
||||
/// / repetition penalty
|
||||
#[prost(float, tag = "7")]
|
||||
pub repetition_penalty: f32,
|
||||
/// / frequency penalty
|
||||
#[prost(float, tag = "9")]
|
||||
pub frequency_penalty: f32,
|
||||
/// / token watermarking using "A Watermark for Large Language Models"
|
||||
#[prost(bool, tag = "8")]
|
||||
pub watermark: bool,
|
||||
/// / grammar (applied if not empty)
|
||||
#[prost(string, tag = "10")]
|
||||
pub grammar: ::prost::alloc::string::String,
|
||||
/// / grammar type
|
||||
#[prost(enumeration = "GrammarType", tag = "11")]
|
||||
pub grammar_type: i32,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct StoppingCriteriaParameters {
|
||||
/// / Maximum number of generated tokens
|
||||
#[prost(uint32, tag = "1")]
|
||||
pub max_new_tokens: u32,
|
||||
/// / Optional stopping sequences
|
||||
#[prost(string, repeated, tag = "2")]
|
||||
pub stop_sequences: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||
/// / Ignore end of sequence token
|
||||
/// / used for benchmarking
|
||||
#[prost(bool, tag = "3")]
|
||||
pub ignore_eos_token: bool,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Request {
|
||||
/// / Request ID
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub id: u64,
|
||||
/// / The generation context as chunks
|
||||
#[prost(message, optional, tag = "8")]
|
||||
pub input_chunks: ::core::option::Option<Input>,
|
||||
/// / The generation context, stringified input_chunks
|
||||
#[prost(string, tag = "2")]
|
||||
pub inputs: ::prost::alloc::string::String,
|
||||
/// / Context truncation
|
||||
#[prost(uint32, tag = "3")]
|
||||
pub truncate: u32,
|
||||
/// / Next Token Chooser Parameters
|
||||
#[prost(message, optional, tag = "4")]
|
||||
pub parameters: ::core::option::Option<NextTokenChooserParameters>,
|
||||
/// / Stopping Criteria Parameters
|
||||
#[prost(message, optional, tag = "5")]
|
||||
pub stopping_parameters: ::core::option::Option<StoppingCriteriaParameters>,
|
||||
/// / Return prefill logprobs
|
||||
#[prost(bool, tag = "6")]
|
||||
pub prefill_logprobs: bool,
|
||||
/// / Return most likely n tokens
|
||||
#[prost(uint32, tag = "7")]
|
||||
pub top_n_tokens: u32,
|
||||
/// / Paged attention blocks
|
||||
#[prost(uint32, repeated, tag = "9")]
|
||||
pub blocks: ::prost::alloc::vec::Vec<u32>,
|
||||
/// / Paged attention slots
|
||||
#[prost(uint32, repeated, tag = "10")]
|
||||
pub slots: ::prost::alloc::vec::Vec<u32>,
|
||||
/// / LORA adapter index
|
||||
#[prost(string, optional, tag = "11")]
|
||||
pub adapter_id: ::core::option::Option<::prost::alloc::string::String>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Batch {
|
||||
/// / Batch ID
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub id: u64,
|
||||
/// / Individual requests
|
||||
#[prost(message, repeated, tag = "2")]
|
||||
pub requests: ::prost::alloc::vec::Vec<Request>,
|
||||
/// / Batch size (==len(requests))
|
||||
#[prost(uint32, tag = "3")]
|
||||
pub size: u32,
|
||||
/// / Maximum number of tokens this batch will grow to
|
||||
#[prost(uint32, tag = "4")]
|
||||
pub max_tokens: u32,
|
||||
/// / Maximum number of Paged Attention blocks
|
||||
#[prost(uint32, tag = "5")]
|
||||
pub max_blocks: u32,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct CachedBatch {
|
||||
/// / Batch ID
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub id: u64,
|
||||
/// / Individual requests ids
|
||||
#[prost(uint64, repeated, tag = "2")]
|
||||
pub request_ids: ::prost::alloc::vec::Vec<u64>,
|
||||
/// / Batch size (==len(requests))
|
||||
#[prost(uint32, tag = "3")]
|
||||
pub size: u32,
|
||||
/// / Maximum number of tokens this batch will grow to
|
||||
#[prost(uint32, tag = "4")]
|
||||
pub max_tokens: u32,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct GeneratedText {
|
||||
/// / Output
|
||||
#[prost(string, tag = "1")]
|
||||
pub text: ::prost::alloc::string::String,
|
||||
/// / Number of generated tokens
|
||||
#[prost(uint32, tag = "2")]
|
||||
pub generated_tokens: u32,
|
||||
/// / Finish reason
|
||||
#[prost(enumeration = "FinishReason", tag = "3")]
|
||||
pub finish_reason: i32,
|
||||
/// / Seed
|
||||
#[prost(uint64, optional, tag = "4")]
|
||||
pub seed: ::core::option::Option<u64>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Tokens {
|
||||
/// / Token IDs
|
||||
#[prost(uint32, repeated, tag = "1")]
|
||||
pub ids: ::prost::alloc::vec::Vec<u32>,
|
||||
/// / Logprobs
|
||||
#[prost(float, repeated, tag = "2")]
|
||||
pub logprobs: ::prost::alloc::vec::Vec<f32>,
|
||||
/// / tokens
|
||||
#[prost(string, repeated, tag = "3")]
|
||||
pub texts: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
|
||||
/// / special
|
||||
#[prost(bool, repeated, tag = "4")]
|
||||
pub is_special: ::prost::alloc::vec::Vec<bool>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct Generation {
|
||||
/// / Request ID
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub request_id: u64,
|
||||
/// / Prefill tokens (optional)
|
||||
#[prost(message, optional, tag = "2")]
|
||||
pub prefill_tokens: ::core::option::Option<Tokens>,
|
||||
#[prost(message, optional, tag = "3")]
|
||||
pub tokens: ::core::option::Option<Tokens>,
|
||||
/// / Complete generated text
|
||||
#[prost(message, optional, tag = "4")]
|
||||
pub generated_text: ::core::option::Option<GeneratedText>,
|
||||
/// / Top tokens
|
||||
#[prost(message, repeated, tag = "5")]
|
||||
pub top_tokens: ::prost::alloc::vec::Vec<Tokens>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct FilterBatchRequest {
|
||||
/// / Batch ID
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub batch_id: u64,
|
||||
/// / Requests to keep
|
||||
#[prost(uint64, repeated, tag = "2")]
|
||||
pub request_ids: ::prost::alloc::vec::Vec<u64>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct FilterBatchResponse {
|
||||
/// / Filtered Batch (cached)
|
||||
#[prost(message, optional, tag = "1")]
|
||||
pub batch: ::core::option::Option<CachedBatch>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct PrefillRequest {
|
||||
/// / Batch
|
||||
#[prost(message, optional, tag = "1")]
|
||||
pub batch: ::core::option::Option<Batch>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct PrefillResponse {
|
||||
/// / Generation
|
||||
#[prost(message, repeated, tag = "1")]
|
||||
pub generations: ::prost::alloc::vec::Vec<Generation>,
|
||||
/// / Next batch (cached)
|
||||
#[prost(message, optional, tag = "2")]
|
||||
pub batch: ::core::option::Option<CachedBatch>,
|
||||
/// / Forward elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "3")]
|
||||
pub forward_ns: u64,
|
||||
/// / Decode elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "4")]
|
||||
pub decode_ns: u64,
|
||||
/// / Total elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "5")]
|
||||
pub total_ns: u64,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct DecodeRequest {
|
||||
/// / Cached batches
|
||||
#[prost(message, repeated, tag = "1")]
|
||||
pub batches: ::prost::alloc::vec::Vec<CachedBatch>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct DecodeResponse {
|
||||
/// / Decodes
|
||||
#[prost(message, repeated, tag = "1")]
|
||||
pub generations: ::prost::alloc::vec::Vec<Generation>,
|
||||
/// / Next batch (cached)
|
||||
#[prost(message, optional, tag = "2")]
|
||||
pub batch: ::core::option::Option<CachedBatch>,
|
||||
/// / Forward elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "3")]
|
||||
pub forward_ns: u64,
|
||||
/// / Decode elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "4")]
|
||||
pub decode_ns: u64,
|
||||
/// / Total elapsed time in nanoseconds
|
||||
#[prost(uint64, tag = "5")]
|
||||
pub total_ns: u64,
|
||||
/// / Concatenate elapsed time in nanoseconds
|
||||
#[prost(uint64, optional, tag = "6")]
|
||||
pub concat_ns: ::core::option::Option<u64>,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct WarmupRequest {
|
||||
/// / Batch to warmup on
|
||||
#[prost(message, optional, tag = "1")]
|
||||
pub batch: ::core::option::Option<Batch>,
|
||||
#[prost(uint32, tag = "2")]
|
||||
pub max_input_length: u32,
|
||||
#[prost(uint32, tag = "3")]
|
||||
pub max_prefill_tokens: u32,
|
||||
#[prost(uint32, tag = "4")]
|
||||
pub max_total_tokens: u32,
|
||||
}
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[derive(Clone, PartialEq, ::prost::Message)]
|
||||
pub struct WarmupResponse {
|
||||
/// / Maximum number of tokens supported by the model
|
||||
#[prost(uint32, optional, tag = "1")]
|
||||
pub max_supported_total_tokens: ::core::option::Option<u32>,
|
||||
}
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
|
||||
#[repr(i32)]
|
||||
pub enum GrammarType {
|
||||
None = 0,
|
||||
Json = 1,
|
||||
Regex = 2,
|
||||
}
|
||||
impl GrammarType {
|
||||
/// String value of the enum field names used in the ProtoBuf definition.
|
||||
///
|
||||
/// The values are not transformed in any way and thus are considered stable
|
||||
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
|
||||
pub fn as_str_name(&self) -> &'static str {
|
||||
match self {
|
||||
GrammarType::None => "GRAMMAR_TYPE_NONE",
|
||||
GrammarType::Json => "GRAMMAR_TYPE_JSON",
|
||||
GrammarType::Regex => "GRAMMAR_TYPE_REGEX",
|
||||
}
|
||||
}
|
||||
/// Creates an enum from field names used in the ProtoBuf definition.
|
||||
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
|
||||
match value {
|
||||
"GRAMMAR_TYPE_NONE" => Some(Self::None),
|
||||
"GRAMMAR_TYPE_JSON" => Some(Self::Json),
|
||||
"GRAMMAR_TYPE_REGEX" => Some(Self::Regex),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
|
||||
#[repr(i32)]
|
||||
pub enum FinishReason {
|
||||
Length = 0,
|
||||
EosToken = 1,
|
||||
StopSequence = 2,
|
||||
}
|
||||
impl FinishReason {
|
||||
/// String value of the enum field names used in the ProtoBuf definition.
|
||||
///
|
||||
/// The values are not transformed in any way and thus are considered stable
|
||||
/// (if the ProtoBuf definition does not change) and safe for programmatic use.
|
||||
pub fn as_str_name(&self) -> &'static str {
|
||||
match self {
|
||||
FinishReason::Length => "FINISH_REASON_LENGTH",
|
||||
FinishReason::EosToken => "FINISH_REASON_EOS_TOKEN",
|
||||
FinishReason::StopSequence => "FINISH_REASON_STOP_SEQUENCE",
|
||||
}
|
||||
}
|
||||
/// Creates an enum from field names used in the ProtoBuf definition.
|
||||
pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
|
||||
match value {
|
||||
"FINISH_REASON_LENGTH" => Some(Self::Length),
|
||||
"FINISH_REASON_EOS_TOKEN" => Some(Self::EosToken),
|
||||
"FINISH_REASON_STOP_SEQUENCE" => Some(Self::StopSequence),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Generated client implementations.
|
||||
pub mod text_generation_service_client {
|
||||
#![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)]
|
||||
use tonic::codegen::*;
|
||||
use tonic::codegen::http::Uri;
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TextGenerationServiceClient<T> {
|
||||
inner: tonic::client::Grpc<T>,
|
||||
}
|
||||
impl TextGenerationServiceClient<tonic::transport::Channel> {
|
||||
/// Attempt to create a new client by connecting to a given endpoint.
|
||||
pub async fn connect<D>(dst: D) -> Result<Self, tonic::transport::Error>
|
||||
where
|
||||
D: TryInto<tonic::transport::Endpoint>,
|
||||
D::Error: Into<StdError>,
|
||||
{
|
||||
let conn = tonic::transport::Endpoint::new(dst)?.connect().await?;
|
||||
Ok(Self::new(conn))
|
||||
}
|
||||
}
|
||||
impl<T> TextGenerationServiceClient<T>
|
||||
where
|
||||
T: tonic::client::GrpcService<tonic::body::BoxBody>,
|
||||
T::Error: Into<StdError>,
|
||||
T::ResponseBody: Body<Data = Bytes> + Send + 'static,
|
||||
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
|
||||
{
|
||||
pub fn new(inner: T) -> Self {
|
||||
let inner = tonic::client::Grpc::new(inner);
|
||||
Self { inner }
|
||||
}
|
||||
pub fn with_origin(inner: T, origin: Uri) -> Self {
|
||||
let inner = tonic::client::Grpc::with_origin(inner, origin);
|
||||
Self { inner }
|
||||
}
|
||||
pub fn with_interceptor<F>(
|
||||
inner: T,
|
||||
interceptor: F,
|
||||
) -> TextGenerationServiceClient<InterceptedService<T, F>>
|
||||
where
|
||||
F: tonic::service::Interceptor,
|
||||
T::ResponseBody: Default,
|
||||
T: tonic::codegen::Service<
|
||||
http::Request<tonic::body::BoxBody>,
|
||||
Response = http::Response<
|
||||
<T as tonic::client::GrpcService<tonic::body::BoxBody>>::ResponseBody,
|
||||
>,
|
||||
>,
|
||||
<T as tonic::codegen::Service<
|
||||
http::Request<tonic::body::BoxBody>,
|
||||
>>::Error: Into<StdError> + Send + Sync,
|
||||
{
|
||||
TextGenerationServiceClient::new(InterceptedService::new(inner, interceptor))
|
||||
}
|
||||
/// Compress requests with the given encoding.
|
||||
///
|
||||
/// This requires the server to support it otherwise it might respond with an
|
||||
/// error.
|
||||
#[must_use]
|
||||
pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
|
||||
self.inner = self.inner.send_compressed(encoding);
|
||||
self
|
||||
}
|
||||
/// Enable decompressing responses.
|
||||
#[must_use]
|
||||
pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
|
||||
self.inner = self.inner.accept_compressed(encoding);
|
||||
self
|
||||
}
|
||||
/// Limits the maximum size of a decoded message.
|
||||
///
|
||||
/// Default: `4MB`
|
||||
#[must_use]
|
||||
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
|
||||
self.inner = self.inner.max_decoding_message_size(limit);
|
||||
self
|
||||
}
|
||||
/// Limits the maximum size of an encoded message.
|
||||
///
|
||||
/// Default: `usize::MAX`
|
||||
#[must_use]
|
||||
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
|
||||
self.inner = self.inner.max_encoding_message_size(limit);
|
||||
self
|
||||
}
|
||||
/// / Model Info
|
||||
pub async fn info(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::InfoRequest>,
|
||||
) -> std::result::Result<tonic::Response<super::InfoResponse>, tonic::Status> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v3.TextGenerationService/Info",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(GrpcMethod::new("generate.v3.TextGenerationService", "Info"));
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Service discovery
|
||||
pub async fn service_discovery(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::ServiceDiscoveryRequest>,
|
||||
) -> std::result::Result<
|
||||
tonic::Response<super::ServiceDiscoveryResponse>,
|
||||
tonic::Status,
|
||||
> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v3.TextGenerationService/ServiceDiscovery",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(
|
||||
GrpcMethod::new(
|
||||
"generate.v3.TextGenerationService",
|
||||
"ServiceDiscovery",
|
||||
),
|
||||
);
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Empties batch cache
|
||||
pub async fn clear_cache(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::ClearCacheRequest>,
|
||||
) -> std::result::Result<
|
||||
tonic::Response<super::ClearCacheResponse>,
|
||||
tonic::Status,
|
||||
> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v3.TextGenerationService/ClearCache",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(
|
||||
GrpcMethod::new("generate.v3.TextGenerationService", "ClearCache"),
|
||||
);
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Remove requests from a cached batch
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::FilterBatchRequest>,
|
||||
) -> std::result::Result<
|
||||
tonic::Response<super::FilterBatchResponse>,
|
||||
tonic::Status,
|
||||
> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v3.TextGenerationService/FilterBatch",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(
|
||||
GrpcMethod::new("generate.v3.TextGenerationService", "FilterBatch"),
|
||||
);
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Warmup the model and compute max cache size
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::WarmupRequest>,
|
||||
) -> std::result::Result<tonic::Response<super::WarmupResponse>, tonic::Status> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v3.TextGenerationService/Warmup",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(GrpcMethod::new("generate.v3.TextGenerationService", "Warmup"));
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Prefill batch and decode first token
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::PrefillRequest>,
|
||||
) -> std::result::Result<
|
||||
tonic::Response<super::PrefillResponse>,
|
||||
tonic::Status,
|
||||
> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v3.TextGenerationService/Prefill",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(GrpcMethod::new("generate.v3.TextGenerationService", "Prefill"));
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Decode token for a list of prefilled batches
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::DecodeRequest>,
|
||||
) -> std::result::Result<tonic::Response<super::DecodeResponse>, tonic::Status> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v3.TextGenerationService/Decode",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(GrpcMethod::new("generate.v3.TextGenerationService", "Decode"));
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
/// / Health check
|
||||
pub async fn health(
|
||||
&mut self,
|
||||
request: impl tonic::IntoRequest<super::HealthRequest>,
|
||||
) -> std::result::Result<tonic::Response<super::HealthResponse>, tonic::Status> {
|
||||
self.inner
|
||||
.ready()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tonic::Status::new(
|
||||
tonic::Code::Unknown,
|
||||
format!("Service was not ready: {}", e.into()),
|
||||
)
|
||||
})?;
|
||||
let codec = tonic::codec::ProstCodec::default();
|
||||
let path = http::uri::PathAndQuery::from_static(
|
||||
"/generate.v3.TextGenerationService/Health",
|
||||
);
|
||||
let mut req = request.into_request();
|
||||
req.extensions_mut()
|
||||
.insert(GrpcMethod::new("generate.v3.TextGenerationService", "Health"));
|
||||
self.inner.unary(req, path, codec).await
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
// This file is @generated by prost-build.
|
||||
pub mod generate {
|
||||
pub mod v3 {
|
||||
include!("generate.v3.rs");
|
||||
}
|
||||
}
|
|
@ -1,528 +1,83 @@
|
|||
/// Batching and inference logic
|
||||
use crate::infer::v3::queue::{Entry, Queue};
|
||||
use crate::infer::{
|
||||
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
|
||||
};
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::{FinishReason, PrefillToken, Token};
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
|
||||
use text_generation_client::ClientError;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
use crate::infer::InferError;
|
||||
use crate::{ChatTemplateInputs, GrammarType, Message, MessageChunk, Text, TextMessage};
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
|
||||
pub(crate) struct SchedulerV3 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
/// Raise a exception (custom function) used in the chat templates
|
||||
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
||||
}
|
||||
|
||||
impl SchedulerV3 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ChatTemplate {
|
||||
template: Template<'static, 'static>,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
use_default_tool_template: bool,
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
template: String,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
) -> Self {
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
let mut env = Box::new(Environment::new());
|
||||
// enable things like .strip() or .capitalize()
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
queue.clone(),
|
||||
batching_task_notifier.clone(),
|
||||
generation_health,
|
||||
));
|
||||
// check if contains the tools variable within the template
|
||||
let use_default_tool_template =
|
||||
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
let template = Box::leak(env)
|
||||
.template_from_str(Box::leak(template_str))
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
queue,
|
||||
batching_task_notifier,
|
||||
template,
|
||||
bos_token,
|
||||
eos_token,
|
||||
use_default_tool_template,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Scheduler for SchedulerV3 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
pub(crate) fn apply(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
permit: OwnedSemaphorePermit,
|
||||
) -> Result<GenerateStreamResponse, InferError> {
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
let input_length = request.input_length;
|
||||
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
request,
|
||||
response_tx,
|
||||
span: Span::current(),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
block_allocation: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
// to be batched
|
||||
self.batching_task_notifier.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok((
|
||||
permit,
|
||||
input_length,
|
||||
UnboundedReceiverStream::new(response_rx),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
queue: Queue,
|
||||
notifier: Arc<Notify>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) {
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
notifier.notified().await;
|
||||
|
||||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue
|
||||
.next_batch(
|
||||
None,
|
||||
max_batch_size,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
// all requests have met their stopping criteria)
|
||||
while let Some(batch) = cached_batch {
|
||||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
.await
|
||||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||
.increment(1);
|
||||
} else {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||
// Add relationships
|
||||
span.follows_from(&entry_waiting_span);
|
||||
entry_waiting_span.follows_from(&span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_waiting_span);
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch =
|
||||
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
entries.extend(new_entries);
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
mut messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content.push(MessageChunk::Text(Text {
|
||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||
}));
|
||||
}
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_size = entries.len();
|
||||
let next_batch_span =
|
||||
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
generation_health: &Arc<AtomicBool>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
// Update health
|
||||
generation_health.store(true, Ordering::SeqCst);
|
||||
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
// Update health
|
||||
generation_health.store(false, Ordering::SeqCst);
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<CachedBatch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
generation_health: &Arc<AtomicBool>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
// Update health
|
||||
generation_health.store(true, Ordering::SeqCst);
|
||||
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||
.record(concat_duration.as_secs_f64());
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
generation_health.store(false, Ordering::SeqCst);
|
||||
for id in batch_ids {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
#[instrument(skip_all)]
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
next_batch: Option<CachedBatch>,
|
||||
entries: &IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let mut batch = next_batch?;
|
||||
|
||||
// No need to filter
|
||||
if batch.size as usize == entries.len() {
|
||||
return Some(batch);
|
||||
}
|
||||
|
||||
let id = batch.id;
|
||||
|
||||
// Retain only requests that are still in entries
|
||||
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||
|
||||
if batch.request_ids.is_empty() {
|
||||
// All requests have been filtered out
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.clear_cache(Some(id)).await.unwrap();
|
||||
None
|
||||
} else {
|
||||
// Filter Python shard cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
/// and filter entries
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
generations.into_iter().for_each(|generation| {
|
||||
let id = generation.request_id;
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let entry = entries
|
||||
.get(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||
// Send generation responses back to the infer task
|
||||
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
err
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Send responses through the `entry` response channel
|
||||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let mut stopped = false;
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Create Token objects
|
||||
// We do that here instead of in the Python code as Rust for loops are faster
|
||||
let prefill_tokens = prefill_tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(prefill_tokens.logprobs)
|
||||
.zip(prefill_tokens.texts)
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(tokens_.logprobs)
|
||||
.zip(tokens_.texts)
|
||||
.zip(tokens_.is_special)
|
||||
.enumerate()
|
||||
.peekable();
|
||||
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||
let token = Token {
|
||||
id,
|
||||
text,
|
||||
logprob,
|
||||
special,
|
||||
};
|
||||
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||
top_tokens_
|
||||
.ids
|
||||
.iter()
|
||||
.zip(top_tokens_.logprobs.iter())
|
||||
.zip(top_tokens_.texts.iter())
|
||||
.zip(top_tokens_.is_special.iter())
|
||||
.map(|(((&id, &logprob), text), &special)| Token {
|
||||
id,
|
||||
text: text.to_string(),
|
||||
logprob,
|
||||
special,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
match (&generation.generated_text, iterator.peek()) {
|
||||
(Some(generated_text), None) => {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: GeneratedText::from(generated_text.clone()),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
}
|
||||
_ => {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stopped)
|
||||
}
|
||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
|
||||
fn from(value: text_generation_client::v3::GeneratedText) -> Self {
|
||||
let v3_finish_reason =
|
||||
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v3_finish_reason {
|
||||
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
|
||||
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
};
|
||||
|
||||
Self {
|
||||
text: value.text,
|
||||
generated_tokens: value.generated_tokens,
|
||||
finish_reason,
|
||||
seed: value.seed,
|
||||
}
|
||||
self.template
|
||||
.render(ChatTemplateInputs {
|
||||
messages,
|
||||
bos_token: self.bos_token.as_deref(),
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
tools: None,
|
||||
tools_prompt: None,
|
||||
})
|
||||
.map_err(InferError::TemplateError)
|
||||
}
|
||||
}
|
||||
|
||||
// tests
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::infer::raise_exception;
|
||||
use crate::infer::chat_template::raise_exception;
|
||||
use crate::{ChatTemplateInputs, TextMessage};
|
||||
use minijinja::Environment;
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::Health;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct HealthCheck {
|
||||
client: Arc<dyn Health + Send + Sync>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl HealthCheck {
|
||||
pub(crate) fn new(
|
||||
client: Arc<dyn Health + Send + Sync>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
generation_health,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn check(&mut self) -> bool {
|
||||
let value = if self.generation_health.load(Ordering::SeqCst) {
|
||||
// Generation is healthy, we only check that the shards can allocate on device
|
||||
self.client.device_health().await
|
||||
} else {
|
||||
self.client.model_health().await
|
||||
}
|
||||
.is_ok();
|
||||
// Update generation health
|
||||
self.generation_health.store(value, Ordering::SeqCst);
|
||||
value
|
||||
}
|
||||
}
|
|
@ -1,23 +1,18 @@
|
|||
mod health;
|
||||
pub(crate) mod v2;
|
||||
pub(crate) mod v3;
|
||||
|
||||
pub(crate) use health::HealthCheck;
|
||||
// pub(crate) mod v2;
|
||||
mod chat_template;
|
||||
pub mod tool_grammar;
|
||||
|
||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||
use crate::GrammarType;
|
||||
use crate::{
|
||||
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice,
|
||||
};
|
||||
use crate::{
|
||||
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
|
||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||
Message, PrefillToken, Token,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chat_template::ChatTemplate;
|
||||
use futures::future::try_join_all;
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
use minijinja::ErrorKind;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||
|
@ -26,12 +21,14 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
|
|||
use tokio_stream::StreamExt;
|
||||
use tracing::instrument;
|
||||
|
||||
pub(crate) trait Scheduler {
|
||||
#[async_trait]
|
||||
pub trait Backend {
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
permit: OwnedSemaphorePermit,
|
||||
) -> Result<GenerateStreamResponse, InferError>;
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
|
||||
|
||||
async fn health(&self, current_health: bool) -> bool;
|
||||
}
|
||||
|
||||
/// Inference struct
|
||||
|
@ -39,18 +36,20 @@ pub(crate) trait Scheduler {
|
|||
pub struct Infer {
|
||||
/// Validation
|
||||
validation: Validation,
|
||||
/// Request scheduler
|
||||
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
||||
/// Request backend
|
||||
backend: Arc<dyn Backend + Send + Sync>,
|
||||
/// Chat template
|
||||
chat_template: Option<ChatTemplate>,
|
||||
/// Inference limit
|
||||
limit_concurrent_requests: Arc<Semaphore>,
|
||||
/// Backend health
|
||||
backend_health: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl Infer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
||||
backend: impl Backend + Send + Sync + 'static,
|
||||
validation: Validation,
|
||||
max_concurrent_requests: usize,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
|
@ -71,18 +70,22 @@ impl Infer {
|
|||
// Inference limit with a semaphore
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||
|
||||
// Backend health
|
||||
let backend_health = Arc::new(AtomicBool::new(false));
|
||||
|
||||
Self {
|
||||
validation,
|
||||
scheduler,
|
||||
backend: Arc::new(backend),
|
||||
chat_template,
|
||||
limit_concurrent_requests: semaphore,
|
||||
backend_health,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a new request to the queue and return a stream of InferStreamResponse
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) async fn generate_stream(
|
||||
&self,
|
||||
pub(crate) async fn generate_stream<'a>(
|
||||
&'a self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<GenerateStreamResponse, InferError> {
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
|
@ -103,7 +106,10 @@ impl Infer {
|
|||
err
|
||||
})?;
|
||||
|
||||
self.scheduler.schedule(valid_request, permit)
|
||||
let input_length = valid_request.input_length;
|
||||
let generation_stream = self.backend.schedule(valid_request)?;
|
||||
|
||||
Ok((permit, input_length, generation_stream))
|
||||
}
|
||||
|
||||
/// Tokenizer the input
|
||||
|
@ -155,7 +161,7 @@ impl Infer {
|
|||
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
|
||||
|
||||
// Create stream and keep semaphore permit as long as generate lives
|
||||
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
|
||||
let (_permit, _input_length, stream) = self.generate_stream(request).await?;
|
||||
|
||||
// Return values
|
||||
let mut result_prefill = Vec::new();
|
||||
|
@ -165,6 +171,8 @@ impl Infer {
|
|||
let mut result_start = None;
|
||||
let mut result_queued = None;
|
||||
|
||||
let mut stream = Box::pin(stream);
|
||||
|
||||
// Iterate on stream
|
||||
while let Some(response) = stream.next().await {
|
||||
match response? {
|
||||
|
@ -256,207 +264,15 @@ impl Infer {
|
|||
let best_response = infer_responses.remove(max_index);
|
||||
Ok((best_response, infer_responses))
|
||||
}
|
||||
}
|
||||
|
||||
/// Raise a exception (custom function) used in the chat templates
|
||||
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ChatTemplate {
|
||||
template: Template<'static, 'static>,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
use_default_tool_template: bool,
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
fn new(
|
||||
template: String,
|
||||
bos_token: Option<TokenizerConfigToken>,
|
||||
eos_token: Option<TokenizerConfigToken>,
|
||||
) -> Self {
|
||||
let mut env = Box::new(Environment::new());
|
||||
// enable things like .strip() or .capitalize()
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
// check if contains the tools variable within the template
|
||||
let use_default_tool_template =
|
||||
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
let template = Box::leak(env)
|
||||
.template_from_str(Box::leak(template_str))
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
template,
|
||||
bos_token: bos_token.map(|token| token.as_str().to_string()),
|
||||
eos_token: eos_token.map(|token| token.as_str().to_string()),
|
||||
use_default_tool_template,
|
||||
}
|
||||
}
|
||||
|
||||
fn apply(
|
||||
&self,
|
||||
mut messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content.push(MessageChunk::Text {
|
||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
|
||||
self.template
|
||||
.render(ChatTemplateInputs {
|
||||
messages,
|
||||
bos_token: self.bos_token.as_deref(),
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
tools: None,
|
||||
tools_prompt: None,
|
||||
})
|
||||
.map_err(InferError::TemplateError)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolGrammar {}
|
||||
|
||||
impl ToolGrammar {
|
||||
// find a tool by name
|
||||
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
|
||||
tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == name)
|
||||
.cloned()
|
||||
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
|
||||
}
|
||||
|
||||
pub fn apply(
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: ToolChoice,
|
||||
) -> Result<Option<Tools>, InferError> {
|
||||
// if no tools are provided, we return None
|
||||
let tools = match tools {
|
||||
Some(tools) if !tools.is_empty() => tools,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||
|
||||
// if tools are provided and no tool_choice we default to the OneOf
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![Self::find_tool_by_name(&tools, &name)?]
|
||||
}
|
||||
ToolType::Function { function } => {
|
||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||
}
|
||||
ToolType::OneOf => tools,
|
||||
ToolType::NoTool => return Ok(None),
|
||||
};
|
||||
|
||||
// adds the error notification function for LLM feedback if required
|
||||
let mut text_response_properties = Map::new();
|
||||
text_response_properties.insert(
|
||||
"error".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
}),
|
||||
);
|
||||
text_response_properties.insert(
|
||||
"_name".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"const": "notify_error"
|
||||
}),
|
||||
);
|
||||
|
||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
|
||||
// Clone the existing parameters, which are expected to be a JSON object
|
||||
let mut params = if let Value::Object(params) = &func.arguments {
|
||||
params.clone()
|
||||
} else {
|
||||
Map::new()
|
||||
};
|
||||
|
||||
// Insert the function's description at the top level, outside of properties
|
||||
params.insert(
|
||||
"description".to_string(),
|
||||
Value::String(func.description.clone().unwrap_or_default()),
|
||||
);
|
||||
|
||||
// Ensure 'properties' exists and is an object
|
||||
let properties = params
|
||||
.entry("properties".to_string())
|
||||
.or_insert_with(|| json!({}))
|
||||
.as_object_mut()
|
||||
.unwrap();
|
||||
|
||||
// Insert the constant for the function name inside 'properties'
|
||||
properties.insert(
|
||||
"_name".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"const": func.name.clone(),
|
||||
// "description": "The name of the function"
|
||||
}),
|
||||
);
|
||||
|
||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
||||
let required = params
|
||||
.entry("required".to_string())
|
||||
.or_insert_with(|| json!([]))
|
||||
.as_array_mut()
|
||||
.unwrap();
|
||||
|
||||
// Add 'name' to the 'required' array if it is not already present
|
||||
if !required.iter().any(|r| r == "_name") {
|
||||
required.push(json!("_name"));
|
||||
}
|
||||
|
||||
(func.name, Value::Object(params))
|
||||
})
|
||||
.chain([(
|
||||
"notify_error".to_string(),
|
||||
serde_json::json!({
|
||||
"properties": text_response_properties,
|
||||
"required": ["error", "_name"],
|
||||
"type": "object"
|
||||
}),
|
||||
)])
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.chain(std::iter::once(FunctionRef {
|
||||
ref_path: "#/$functions/notify_error".to_string(),
|
||||
}))
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Some(tools))
|
||||
#[instrument(skip(self))]
|
||||
pub(crate) async fn health(&self) -> bool {
|
||||
let health = self
|
||||
.backend
|
||||
.health(self.backend_health.load(Ordering::SeqCst))
|
||||
.await;
|
||||
self.backend_health.store(health, Ordering::SeqCst);
|
||||
health
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -468,15 +284,15 @@ pub(crate) type GenerateStreamResponse = (
|
|||
);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct GeneratedText {
|
||||
pub(crate) text: String,
|
||||
pub(crate) generated_tokens: u32,
|
||||
pub(crate) finish_reason: FinishReason,
|
||||
pub(crate) seed: Option<u64>,
|
||||
pub struct GeneratedText {
|
||||
pub text: String,
|
||||
pub generated_tokens: u32,
|
||||
pub finish_reason: FinishReason,
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum InferStreamResponse {
|
||||
pub enum InferStreamResponse {
|
||||
// Optional first message
|
||||
Prefill(Vec<PrefillToken>),
|
||||
// Intermediate messages
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
use crate::infer::InferError;
|
||||
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolType, Tools};
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub(crate) struct ToolGrammar {}
|
||||
|
||||
impl ToolGrammar {
|
||||
pub fn apply(
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: Option<ToolType>,
|
||||
) -> Result<Option<Tools>, InferError> {
|
||||
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
|
||||
// let tool_prompt = tool_prompt.unwrap_or_default();
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![req_tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == *name)
|
||||
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
|
||||
.clone()]
|
||||
}
|
||||
ToolType::OneOf => req_tools.to_owned(),
|
||||
};
|
||||
|
||||
// adds the error notification function for LLM feedback if required
|
||||
let mut text_response_properties = Map::new();
|
||||
text_response_properties.insert(
|
||||
"error".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
}),
|
||||
);
|
||||
text_response_properties.insert(
|
||||
"_name".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"const": "notify_error"
|
||||
}),
|
||||
);
|
||||
|
||||
let functions: HashMap<String, Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
|
||||
// Clone the existing parameters, which are expected to be a JSON object
|
||||
let mut params = if let Value::Object(params) = &func.arguments {
|
||||
params.clone()
|
||||
} else {
|
||||
Map::new()
|
||||
};
|
||||
|
||||
// Insert the function's description at the top level, outside of properties
|
||||
params.insert(
|
||||
"description".to_string(),
|
||||
Value::String(func.description.clone().unwrap_or_default()),
|
||||
);
|
||||
|
||||
// Ensure 'properties' exists and is an object
|
||||
let properties = params
|
||||
.entry("properties".to_string())
|
||||
.or_insert_with(|| json!({}))
|
||||
.as_object_mut()
|
||||
.unwrap();
|
||||
|
||||
// Insert the constant for the function name inside 'properties'
|
||||
properties.insert(
|
||||
"_name".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"const": func.name.clone(),
|
||||
// "description": "The name of the function"
|
||||
}),
|
||||
);
|
||||
|
||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
||||
let required = params
|
||||
.entry("required".to_string())
|
||||
.or_insert_with(|| json!([]))
|
||||
.as_array_mut()
|
||||
.unwrap();
|
||||
|
||||
// Add 'name' to the 'required' array if it is not already present
|
||||
if !required.iter().any(|r| r == "_name") {
|
||||
required.push(json!("_name"));
|
||||
}
|
||||
|
||||
(func.name, Value::Object(params))
|
||||
})
|
||||
.chain([(
|
||||
"notify_error".to_string(),
|
||||
json!({
|
||||
"properties": text_response_properties,
|
||||
"required": ["error", "_name"],
|
||||
"type": "object"
|
||||
}),
|
||||
)])
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.chain(std::iter::once(FunctionRef {
|
||||
ref_path: "#/$functions/notify_error".to_string(),
|
||||
}))
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
return Ok(Some(tools));
|
||||
}
|
||||
// Err(InferError::ToolError("No tools provided".to_string()))
|
||||
Ok(None)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
mod queue;
|
||||
mod scheduler;
|
||||
|
||||
pub(crate) use scheduler::SchedulerV2;
|
||||
pub(crate) use scheduler::BackendV2;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/// Batching and inference logic
|
||||
use crate::infer::v2::queue::{Entry, Queue};
|
||||
use crate::infer::{
|
||||
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
|
||||
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
|
||||
};
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::{FinishReason, PrefillToken, Token};
|
||||
|
@ -18,14 +18,14 @@ use tokio::time::Instant;
|
|||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub(crate) struct SchedulerV2 {
|
||||
pub(crate) struct BackendV2 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl SchedulerV2 {
|
||||
impl BackendV2 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
|
@ -69,7 +69,7 @@ impl SchedulerV2 {
|
|||
}
|
||||
}
|
||||
|
||||
impl Scheduler for SchedulerV2 {
|
||||
impl Backend for BackendV2 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
mod block_allocator;
|
||||
mod queue;
|
||||
mod scheduler;
|
||||
|
||||
pub(crate) use scheduler::SchedulerV3;
|
|
@ -1,11 +1,12 @@
|
|||
/// Text Generation Inference Webserver
|
||||
pub mod config;
|
||||
mod infer;
|
||||
pub mod infer;
|
||||
pub mod server;
|
||||
mod validation;
|
||||
pub mod validation;
|
||||
|
||||
#[cfg(feature = "kserve")]
|
||||
mod kserve;
|
||||
pub mod logging;
|
||||
|
||||
pub mod usage_stats;
|
||||
|
||||
|
@ -148,12 +149,13 @@ pub struct Info {
|
|||
pub model_id: String,
|
||||
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
||||
pub model_sha: Option<String>,
|
||||
#[schema(example = "torch.float16")]
|
||||
pub model_dtype: String,
|
||||
#[schema(example = "cuda")]
|
||||
pub model_device_type: String,
|
||||
// #[schema(example = "torch.float16")]
|
||||
// pub model_dtype: String,
|
||||
// #[schema(example = "cuda")]
|
||||
// pub model_device_type: String,
|
||||
#[schema(nullable = true, example = "text-generation")]
|
||||
pub model_pipeline_tag: Option<String>,
|
||||
|
||||
/// Router Parameters
|
||||
#[schema(example = "128")]
|
||||
pub max_concurrent_requests: usize,
|
||||
|
@ -165,18 +167,11 @@ pub struct Info {
|
|||
pub max_input_tokens: usize,
|
||||
#[schema(example = "2048")]
|
||||
pub max_total_tokens: usize,
|
||||
#[schema(example = "1.2")]
|
||||
pub waiting_served_ratio: f32,
|
||||
#[schema(example = "32000")]
|
||||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
#[schema(example = "2")]
|
||||
pub validation_workers: usize,
|
||||
#[schema(example = "32")]
|
||||
pub max_client_batch_size: usize,
|
||||
|
||||
/// Router Info
|
||||
#[schema(example = "text-generation-router")]
|
||||
pub router: &'static str,
|
||||
|
@ -1068,23 +1063,23 @@ impl From<CompatGenerateRequest> for GenerateRequest {
|
|||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct PrefillToken {
|
||||
#[schema(example = 0)]
|
||||
id: u32,
|
||||
pub id: u32,
|
||||
#[schema(example = "test")]
|
||||
text: String,
|
||||
pub text: String,
|
||||
#[schema(nullable = true, example = - 0.34)]
|
||||
logprob: f32,
|
||||
pub logprob: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema, Clone)]
|
||||
pub struct Token {
|
||||
#[schema(example = 0)]
|
||||
id: u32,
|
||||
pub id: u32,
|
||||
#[schema(example = "test")]
|
||||
text: String,
|
||||
pub text: String,
|
||||
#[schema(nullable = true, example = - 0.34)]
|
||||
logprob: f32,
|
||||
pub logprob: f32,
|
||||
#[schema(example = "false")]
|
||||
special: bool,
|
||||
pub special: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
|
@ -1102,7 +1097,7 @@ pub struct SimpleToken {
|
|||
#[derive(Debug, Serialize, ToSchema)]
|
||||
#[serde(rename_all(serialize = "snake_case"))]
|
||||
#[schema(example = "Length")]
|
||||
pub(crate) enum FinishReason {
|
||||
pub enum FinishReason {
|
||||
#[schema(rename = "length")]
|
||||
Length,
|
||||
#[serde(rename = "eos_token")]
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||
use opentelemetry::sdk::trace;
|
||||
use opentelemetry::sdk::trace::Sampler;
|
||||
use opentelemetry::sdk::Resource;
|
||||
use opentelemetry::{global, KeyValue};
|
||||
use opentelemetry_otlp::WithExportConfig;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
||||
|
||||
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
||||
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
||||
/// - otlp_service_name service name to appear in APM
|
||||
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
||||
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
||||
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
|
||||
pub fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
|
||||
let mut layers = Vec::new();
|
||||
|
||||
// STDOUT/STDERR layer
|
||||
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
|
||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||
.with_file(true)
|
||||
.with_ansi(ansi)
|
||||
.with_line_number(true);
|
||||
|
||||
let fmt_layer = match json_output {
|
||||
true => fmt_layer.json().flatten_event(true).boxed(),
|
||||
false => fmt_layer.boxed(),
|
||||
};
|
||||
layers.push(fmt_layer);
|
||||
|
||||
// OpenTelemetry tracing layer
|
||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
||||
|
||||
let tracer = opentelemetry_otlp::new_pipeline()
|
||||
.tracing()
|
||||
.with_exporter(
|
||||
opentelemetry_otlp::new_exporter()
|
||||
.tonic()
|
||||
.with_endpoint(otlp_endpoint),
|
||||
)
|
||||
.with_trace_config(
|
||||
trace::config()
|
||||
.with_resource(Resource::new(vec![KeyValue::new(
|
||||
"service.name",
|
||||
otlp_service_name,
|
||||
)]))
|
||||
.with_sampler(Sampler::AlwaysOn),
|
||||
)
|
||||
.install_batch(opentelemetry::runtime::Tokio);
|
||||
|
||||
if let Ok(tracer) = tracer {
|
||||
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
||||
init_tracing_opentelemetry::init_propagator().unwrap();
|
||||
};
|
||||
}
|
||||
|
||||
// Filter events with LOG_LEVEL
|
||||
let varname = "LOG_LEVEL";
|
||||
let env_filter = if let Ok(log_level) = std::env::var(varname) {
|
||||
// Override to avoid simple logs to be spammed with tokio level informations
|
||||
let log_level = match &log_level[..] {
|
||||
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
|
||||
"info" => "text_generation_launcher=info,text_generation_router=info",
|
||||
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
|
||||
log_level => log_level,
|
||||
};
|
||||
EnvFilter::builder()
|
||||
.with_default_directive(LevelFilter::INFO.into())
|
||||
.parse_lossy(log_level)
|
||||
} else {
|
||||
EnvFilter::new("info")
|
||||
};
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(layers)
|
||||
.init();
|
||||
}
|
|
@ -1,9 +1,7 @@
|
|||
/// HTTP Server logic
|
||||
use crate::config::Config;
|
||||
use crate::infer::v2::SchedulerV2;
|
||||
use crate::infer::v3::SchedulerV3;
|
||||
use crate::infer::{HealthCheck, Scheduler};
|
||||
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
||||
use crate::infer::tool_grammar::ToolGrammar;
|
||||
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
|
||||
#[cfg(feature = "kserve")]
|
||||
use crate::kserve::{
|
||||
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
|
||||
|
@ -27,7 +25,7 @@ use crate::{
|
|||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
||||
use async_stream::__private::AsyncStream;
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::{get, post};
|
||||
|
@ -38,13 +36,15 @@ use futures::stream::{FuturesOrdered, FuturesUnordered};
|
|||
use futures::Stream;
|
||||
use futures::TryStreamExt;
|
||||
use http::header::AUTHORIZATION;
|
||||
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
||||
use hf_hub::{Cache, Repo, RepoType};
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{v2, v3, ClientError, ShardInfo};
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::{Path, PathBuf};
|
||||
use thiserror::Error;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::select;
|
||||
|
@ -124,12 +124,10 @@ responses(
|
|||
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
|
||||
)
|
||||
)]
|
||||
#[instrument(skip(health))]
|
||||
#[instrument(skip(infer))]
|
||||
/// Health check method
|
||||
async fn health(
|
||||
mut health: Extension<HealthCheck>,
|
||||
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
match health.check().await {
|
||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
match infer.health().await {
|
||||
true => Ok(()),
|
||||
false => Err((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
|
@ -430,8 +428,9 @@ async fn generate_stream_internal(
|
|||
} else {
|
||||
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||
// Keep permit as long as generate_stream lives
|
||||
Ok((_permit, _input_length, mut response_stream)) => {
|
||||
Ok((_permit, _input_length, response_stream)) => {
|
||||
let mut index = 0;
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
// Server-Sent Event stream
|
||||
while let Some(response) = response_stream.next().await {
|
||||
index += 1;
|
||||
|
@ -1399,22 +1398,13 @@ pub(crate) struct ComputeType(String);
|
|||
/// Serving method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
master_shard_uds_path: String,
|
||||
model_info: HubModelInfo,
|
||||
compat_return_full_text: bool,
|
||||
backend: impl Backend + Send + Sync + 'static,
|
||||
max_concurrent_requests: usize,
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_top_n_tokens: u32,
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
config: Option<Config>,
|
||||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
|
@ -1518,140 +1508,185 @@ pub async fn run(
|
|||
std::process::exit(0);
|
||||
}
|
||||
|
||||
// Open connection, get model info and warmup
|
||||
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
|
||||
Arc<dyn Scheduler + Send + Sync>,
|
||||
HealthCheck,
|
||||
ShardInfo,
|
||||
u32,
|
||||
) = {
|
||||
// Helper function to check both v2 and v3
|
||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||
match max_supported_batch_total_tokens {
|
||||
// Older models do not support automatic max-batch-total-tokens
|
||||
None => {
|
||||
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
|
||||
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
|
||||
);
|
||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||
Ok(max_batch_total_tokens)
|
||||
// Parse Huggingface hub token
|
||||
let authorization_token = std::env::var("HF_TOKEN")
|
||||
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
|
||||
.ok();
|
||||
|
||||
// Tokenizer instance
|
||||
// This will only be used to validate payloads
|
||||
let local_path = Path::new(&tokenizer_name);
|
||||
|
||||
// Shared API builder initialization
|
||||
let api_builder = || {
|
||||
let mut builder = ApiBuilder::new()
|
||||
.with_progress(false)
|
||||
.with_token(authorization_token);
|
||||
|
||||
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
|
||||
builder = builder.with_cache_dir(cache_dir.into());
|
||||
}
|
||||
|
||||
builder
|
||||
};
|
||||
|
||||
// Decide if we need to use the API based on the revision and local path
|
||||
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
||||
|
||||
// Initialize API if needed
|
||||
#[derive(Clone)]
|
||||
enum Type {
|
||||
Api(Api),
|
||||
Cache(Cache),
|
||||
None,
|
||||
}
|
||||
let api = if use_api {
|
||||
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
||||
let cache = Cache::default();
|
||||
tracing::warn!("Offline mode active using cache defaults");
|
||||
Type::Cache(cache)
|
||||
} else {
|
||||
tracing::info!("Using the Hugging Face API");
|
||||
match api_builder().build() {
|
||||
Ok(api) => Type::Api(api),
|
||||
Err(_) => {
|
||||
tracing::warn!("Unable to build the Hugging Face API");
|
||||
Type::None
|
||||
}
|
||||
// Flash attention models return their max supported total tokens
|
||||
Some(max_supported_batch_total_tokens) => {
|
||||
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||
if max_batch_total_tokens.is_some() {
|
||||
tracing::warn!(
|
||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||
Attention models."
|
||||
);
|
||||
tracing::warn!(
|
||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||
);
|
||||
}
|
||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||
return Err(WebServerError::NotEnoughMemory(max_total_tokens));
|
||||
}
|
||||
|
||||
Ok(max_supported_batch_total_tokens)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let generation_health = Arc::new(AtomicBool::new(false));
|
||||
|
||||
match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await {
|
||||
Ok(mut sharded_client) => {
|
||||
// server is running on v3
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.map_err(WebServerError::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||
sharded_client
|
||||
.warmup(
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(WebServerError::Warmup)?,
|
||||
)?;
|
||||
|
||||
let health_ext =
|
||||
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
|
||||
let scheduler = Arc::new(SchedulerV3::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
generation_health,
|
||||
));
|
||||
tracing::info!("Using scheduler V3");
|
||||
|
||||
(scheduler, health_ext, shard_info, max_batch_total_tokens)
|
||||
}
|
||||
Err(_) => {
|
||||
let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
.map_err(WebServerError::Connection)?;
|
||||
|
||||
// server is running on v2
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.map_err(WebServerError::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||
sharded_client
|
||||
.warmup(
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(WebServerError::Warmup)?,
|
||||
)?;
|
||||
|
||||
let health_ext =
|
||||
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
|
||||
let scheduler = Arc::new(SchedulerV2::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
generation_health,
|
||||
));
|
||||
tracing::info!("Using scheduler V2");
|
||||
|
||||
(scheduler, health_ext, shard_info, max_batch_total_tokens)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Type::None
|
||||
};
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
|
||||
// Load tokenizer and model info
|
||||
let (
|
||||
tokenizer_filename,
|
||||
config_filename,
|
||||
tokenizer_config_filename,
|
||||
processor_config_filename,
|
||||
model_info,
|
||||
) = match api {
|
||||
Type::None => (
|
||||
Some(local_path.join("tokenizer.json")),
|
||||
Some(local_path.join("config.json")),
|
||||
Some(local_path.join("tokenizer_config.json")),
|
||||
Some(local_path.join("processor_config.json")),
|
||||
None,
|
||||
),
|
||||
Type::Api(api) => {
|
||||
let api_repo = api.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||
));
|
||||
|
||||
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
||||
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||
};
|
||||
let config_filename = api_repo.get("config.json").await.ok();
|
||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
||||
|
||||
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
|
||||
Some(model_info)
|
||||
} else {
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
None
|
||||
};
|
||||
(
|
||||
tokenizer_filename,
|
||||
config_filename,
|
||||
tokenizer_config_filename,
|
||||
processor_config_filename,
|
||||
model_info,
|
||||
)
|
||||
}
|
||||
Type::Cache(cache) => {
|
||||
let repo = cache.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||
));
|
||||
(
|
||||
repo.get("tokenizer.json"),
|
||||
repo.get("config.json"),
|
||||
repo.get("tokenizer_config.json"),
|
||||
repo.get("processor_config.json"),
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
let tokenizer: Option<Tokenizer> =
|
||||
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok());
|
||||
let config: Option<Config> = config_filename.and_then(|filename| {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| {
|
||||
let config: Result<Config, _> = serde_json::from_str(c);
|
||||
if let Err(err) = &config {
|
||||
tracing::warn!("Could not parse config {err:?}");
|
||||
}
|
||||
config.ok()
|
||||
})
|
||||
});
|
||||
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
|
||||
model_id: tokenizer_name.to_string(),
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
});
|
||||
|
||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||
{
|
||||
HubTokenizerConfig::from_file(filename)
|
||||
} else {
|
||||
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
||||
};
|
||||
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||
HubTokenizerConfig::default()
|
||||
});
|
||||
|
||||
let processor_config = processor_config_filename
|
||||
.and_then(HubProcessorConfig::from_file)
|
||||
.unwrap_or_default();
|
||||
|
||||
tracing::info!("Using config {config:?}");
|
||||
if tokenizer.is_none() {
|
||||
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
||||
tracing::warn!("Rust input length validation and truncation is disabled");
|
||||
}
|
||||
|
||||
// if pipeline-tag == text-generation we default to return_full_text = true
|
||||
let compat_return_full_text = match &model_info.pipeline_tag {
|
||||
None => {
|
||||
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
||||
true
|
||||
}
|
||||
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
||||
};
|
||||
|
||||
// Determine the server port based on the feature and environment variable.
|
||||
let port = if cfg!(feature = "google") {
|
||||
std::env::var("AIP_HTTP_PORT")
|
||||
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
|
||||
.unwrap_or(port)
|
||||
} else {
|
||||
port
|
||||
};
|
||||
|
||||
let addr = match hostname.parse() {
|
||||
Ok(ip) => SocketAddr::new(ip, port),
|
||||
Err(_) => {
|
||||
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
|
||||
}
|
||||
};
|
||||
|
||||
// Create state
|
||||
let validation = Validation::new(
|
||||
validation_workers,
|
||||
tokenizer,
|
||||
|
@ -1666,7 +1701,7 @@ pub async fn run(
|
|||
);
|
||||
|
||||
let infer = Infer::new(
|
||||
scheduler,
|
||||
backend,
|
||||
validation,
|
||||
max_concurrent_requests,
|
||||
tokenizer_config,
|
||||
|
@ -1703,8 +1738,8 @@ pub async fn run(
|
|||
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
|
||||
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
|
||||
// Speculated tokens buckets
|
||||
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
|
||||
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
|
||||
// let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
|
||||
// let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
|
||||
|
||||
// Prometheus handler
|
||||
let builder = PrometheusBuilder::new()
|
||||
|
@ -1717,9 +1752,9 @@ pub async fn run(
|
|||
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
|
||||
.unwrap()
|
||||
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
|
||||
.unwrap()
|
||||
.set_buckets_for_metric(skipped_matcher, &skipped_buckets)
|
||||
.unwrap();
|
||||
// .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
|
||||
// .unwrap();
|
||||
let prom_handle = builder
|
||||
.install_recorder()
|
||||
.expect("failed to install metrics recorder");
|
||||
|
@ -1735,18 +1770,18 @@ pub async fn run(
|
|||
let info = Info {
|
||||
model_id: model_info.model_id,
|
||||
model_sha: model_info.sha,
|
||||
model_dtype: shard_info.dtype,
|
||||
model_device_type: shard_info.device_type,
|
||||
// model_dtype: shard_info.dtype,
|
||||
// model_device_type: shard_info.device_type,
|
||||
model_pipeline_tag: model_info.pipeline_tag,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
// waiting_served_ratio,
|
||||
// max_batch_total_tokens,
|
||||
// max_waiting_tokens,
|
||||
// max_batch_size,
|
||||
validation_workers,
|
||||
max_client_batch_size,
|
||||
router: env!("CARGO_PKG_NAME"),
|
||||
|
@ -1907,7 +1942,6 @@ pub async fn run(
|
|||
// add layers after routes
|
||||
app = app
|
||||
.layer(Extension(info))
|
||||
.layer(Extension(health_ext.clone()))
|
||||
.layer(Extension(compat_return_full_text))
|
||||
.layer(Extension(infer))
|
||||
.layer(Extension(compute_type))
|
||||
|
@ -1945,6 +1979,68 @@ pub async fn run(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// get model info from the Huggingface Hub
|
||||
pub async fn get_hub_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
|
||||
let response = api.info_request().send().await.ok()?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let hub_model_info: HubModelInfo =
|
||||
serde_json::from_str(&response.text().await.ok()?).ok()?;
|
||||
if let Some(sha) = &hub_model_info.sha {
|
||||
tracing::info!(
|
||||
"Serving revision {sha} of model {}",
|
||||
hub_model_info.model_id
|
||||
);
|
||||
}
|
||||
Some(hub_model_info)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// get base tokenizer
|
||||
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
|
||||
let config_filename = api_repo.get("config.json").await.ok()?;
|
||||
|
||||
// Open the file in read-only mode with buffer.
|
||||
let file = File::open(config_filename).ok()?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
// Read the JSON contents of the file as an instance of `User`.
|
||||
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
|
||||
|
||||
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
|
||||
let api_base_repo = api.repo(Repo::with_revision(
|
||||
base_model_id.to_string(),
|
||||
RepoType::Model,
|
||||
"main".to_string(),
|
||||
));
|
||||
|
||||
api_base_repo.get("tokenizer.json").await.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// get tokenizer_config from the Huggingface Hub
|
||||
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
|
||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
|
||||
|
||||
// Open the file in read-only mode with buffer.
|
||||
let file = File::open(tokenizer_config_filename).ok()?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
|
||||
.map_err(|e| {
|
||||
tracing::warn!("Unable to parse tokenizer config: {}", e);
|
||||
e
|
||||
})
|
||||
.ok()?;
|
||||
|
||||
Some(tokenizer_config)
|
||||
}
|
||||
|
||||
/// Shutdown signal handler
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
|
@ -2008,16 +2104,6 @@ impl From<InferError> for Event {
|
|||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WebServerError {
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
Connection(ClientError),
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to get the Python model shards info: {0}")]
|
||||
Info(ClientError),
|
||||
#[error("Unable to warmup the Python model shards: {0}")]
|
||||
Warmup(ClientError),
|
||||
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||
NotEnoughMemory(usize),
|
||||
#[error("Axum error: {0}")]
|
||||
Axum(#[from] axum::BoxError),
|
||||
}
|
||||
|
|
|
@ -96,7 +96,7 @@ impl Validation {
|
|||
&self,
|
||||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
|
||||
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some(sender) = &self.sender {
|
||||
// Create response channel
|
||||
|
@ -122,7 +122,7 @@ impl Validation {
|
|||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
max_new_tokens: Option<u32>,
|
||||
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
|
||||
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
||||
// Create response channel
|
||||
|
@ -631,18 +631,51 @@ fn prepare_input(
|
|||
|
||||
type TokenizerRequest = (
|
||||
(String, Option<usize>),
|
||||
oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
|
||||
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
|
||||
Span,
|
||||
);
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub struct Image {
|
||||
pub data: Vec<u8>,
|
||||
pub mimetype: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub enum Chunk {
|
||||
Text(String),
|
||||
Image(Image),
|
||||
}
|
||||
|
||||
/// Convert input chunks to a stringly-typed input for backwards
|
||||
/// compat for backends that haven't implemented chunked inputs.
|
||||
pub trait ChunksToString {
|
||||
/// Convert chunks to string.
|
||||
fn chunks_to_string(&self) -> String;
|
||||
}
|
||||
|
||||
impl ChunksToString for Vec<Chunk> {
|
||||
fn chunks_to_string(&self) -> String {
|
||||
let mut output = String::new();
|
||||
self.iter().for_each(|c| match &c {
|
||||
Chunk::Text(text) => output.push_str(text),
|
||||
Chunk::Image(Image { data, mimetype }) => {
|
||||
let encoded = STANDARD.encode(data);
|
||||
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
|
||||
}
|
||||
});
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum ValidGrammar {
|
||||
pub enum ValidGrammar {
|
||||
Json(String),
|
||||
Regex(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidParameters {
|
||||
pub struct ValidParameters {
|
||||
/// / exponential scaling output probability distribution
|
||||
pub temperature: f32,
|
||||
/// / restricting to the k highest probability elements
|
||||
|
@ -666,7 +699,7 @@ pub(crate) struct ValidParameters {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidStoppingParameters {
|
||||
pub struct ValidStoppingParameters {
|
||||
/// / Maximum number of generated tokens
|
||||
pub max_new_tokens: u32,
|
||||
/// / Optional stopping sequences
|
||||
|
@ -677,8 +710,8 @@ pub(crate) struct ValidStoppingParameters {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidGenerateRequest {
|
||||
pub inputs: Vec<InputChunk>,
|
||||
pub struct ValidGenerateRequest {
|
||||
pub inputs: Vec<Chunk>,
|
||||
pub input_length: u32,
|
||||
pub truncate: u32,
|
||||
pub decoder_input_details: bool,
|
||||
|
@ -750,6 +783,9 @@ pub enum ValidationError {
|
|||
InvalidImageContent(String),
|
||||
#[error("Could not fetch image: {0}")]
|
||||
FailedFetchImage(#[from] reqwest::Error),
|
||||
#[error("{0} modality is not supported")]
|
||||
UnsupportedModality(&'static str)
|
||||
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
Loading…
Reference in New Issue