Rebase TRT-llm (#2331)
* wip
wip
refacto
refacto
Initial setup for CXX binding to TRTLLM
Working FFI call for TGI and TRTLLM backend
Remove unused parameters annd force tokenizer name to be set
Overall build TRTLLM and deps through CMake build system
Enable end to end CMake build
First version loading engines and making it ready for inference
Remembering to check how we can detect support for chunked context
Move to latest TensorRT-LLM version
Specify which default log level to use depending on CMake build type
make leader executor mode working
unconditionally call InitializeBackend on the FFI layer
bind to CUDA::nvml to retrieve compute capabilities at runtime
updated logic and comment to detect cuda compute capabilities
implement the Stream method to send new tokens through a callback
use spdlog release 1.14.1 moving forward
update trtllm to latest version a96cccafcf6365c128f004f779160951f8c0801c
correctly tell cmake to build dependent tensorrt-llm required libraries
create cmake install target to put everything relevant in installation folder
add auth_token CLI argument to provide hf hub authentification token
allow converting huggingface::tokenizers error to TensorRtLlmBackendError
use correct include for spdlog
include guard to build example in cmakelists
working setup of the ffi layer
remove fmt import
use external fmt lib
end to end ffi flow working
make sure to track include/ffi.h to trigger rebuild from cargo
impl the rust backend which currently cannot move the actual computation in background thread
expose shutdown function at ffi layer
impl RwLock scenario for TensorRtLllmBackend
oops missing c++ backend definitions
compute the number of maximum new tokens for each request independently
make sure the context is not dropped in the middle of the async decoding.
remove unnecessary log
add all the necessary plumbery to return the generated content
update invalid doc in cpp file
correctly forward back the log probabilities
remove unneeded scope variable for now
refactor Stream impl for Generation to factorise code
expose the internal missing start/queue timestamp
forward tgi parameters rep/freq penalty
add some more validation about grammar not supported
define a shared struct to hold the result of a decoding step
expose information about potential error happening while decoding
remove logging
add logging in case of decoding error
make sure executor_worker is provided
add initial Dockerfile for TRTLLM backend
add some more information in CMakeLists.txt to correctly install executorWorker
add some more information in CMakeLists.txt to correctly find and install nvrtc wrapper
simplify prebuilt trtllm libraries name definition
do the same name definition stuff for tensorrt_llm_executor_static
leverage pkg-config to probe libraries paths and reuse new install structure from cmake
fix bad copy/past missing nvinfer linkage direction
align all the linker search dependency
add missing pkgconfig folder for MPI in Dockerfile
correctly setup linking search path for runtime layer
fix missing / before tgi lib path
adding missing ld_library_path for cuda stubs in Dockerfile
update tgi entrypoint
commenting out Python part for TensorRT installation
refactored docker image
move to TensorRT-LLM v0.11.0
make docker linter happy with same capitalization rule
fix typo
refactor the compute capabilities detection along with num gpus
update TensorRT-LLM to latest version
update TensorRT install script to latest
update build.rs to link to cuda 12.5
add missing dependant libraries for linking
clean up a bit
install to decoder_attention target
add some custom stuff for nccl linkage
fix envvar CARGO_CFG_TARGET_ARCH set at runtime vs compile time
use std::env::const::ARCH
make sure variable live long enough...
look for cuda 12.5
add some more basic info in README.md
* Rebase.
* Fix autodocs.
* Let's try to enable trtllm backend.
* Ignore backends/v3 by default.
* Fixing client.
* Fix makefile + autodocs.
* Updating the schema thing + redocly.
* Fix trtllm lint.
* Adding pb files ?
* Remove cargo fmt temporarily.
* ?
* Tmp.
* Remove both check + clippy ?
* Backporting telemetry.
* Backporting 457fb0a1
* Remove PB from git.
* Fixing PB with default member backends/client
* update TensorRT-LLM to latest version
* provided None for api_key
* link against libtensorrt_llm and not libtensorrt-llm
---------
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Co-authored-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
parent
53aec27328
commit
2b19d671b4
|
@ -2,3 +2,5 @@ aml
|
|||
target
|
||||
server/transformers
|
||||
server/flash-attention
|
||||
cmake-build-debug/
|
||||
cmake-build-release/
|
||||
|
|
|
@ -28,7 +28,7 @@ jobs:
|
|||
|
||||
- name: Install router
|
||||
id: install-router
|
||||
run: cargo install --path router/
|
||||
run: cargo install --path backends/v3/
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
|
@ -41,5 +41,5 @@ jobs:
|
|||
|
||||
- name: Check that documentation is up-to-date
|
||||
run: |
|
||||
npm install -g swagger-cli
|
||||
npm install -g @redocly/cli
|
||||
python update_doc.py --check
|
||||
|
|
|
@ -3,6 +3,10 @@ target
|
|||
router/tokenizer.json
|
||||
*__pycache__*
|
||||
|
||||
backends/v3/src/client/pb
|
||||
backends/client/src/v2/pb
|
||||
backends/client/src/v3/pb
|
||||
|
||||
# ROCm auto-generated files
|
||||
*.hip
|
||||
server/exllamav2_kernels/exllamav2_kernels/hip/
|
||||
|
|
|
@ -13,8 +13,8 @@ repos:
|
|||
- repo: https://github.com/doublify/pre-commit-rust
|
||||
rev: v1.0
|
||||
hooks:
|
||||
- id: fmt
|
||||
- id: cargo-check
|
||||
- id: fmt
|
||||
- id: clippy
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.3.0
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API.
|
||||
# See https://redoc.ly/docs/cli/ for more information.
|
||||
docs/openapi.json:
|
||||
no-empty-servers:
|
||||
- '#/openapi'
|
||||
spec:
|
||||
- >-
|
||||
#/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum
|
||||
- >-
|
||||
#/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum
|
||||
- '#/components/schemas/GenerateParameters/properties/grammar/nullable'
|
||||
- >-
|
||||
#/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum
|
||||
- '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum'
|
||||
- >-
|
||||
#/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum
|
||||
- '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum'
|
||||
- >-
|
||||
#/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum
|
||||
- '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum'
|
||||
- >-
|
||||
#/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum
|
||||
- '#/components/schemas/GenerateResponse/properties/details/nullable'
|
||||
- '#/components/schemas/StreamResponse/properties/details/nullable'
|
||||
- '#/components/schemas/ChatRequest/properties/response_format/nullable'
|
||||
- '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
|
||||
- '#/components/schemas/ToolChoice/nullable'
|
||||
- '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
|
||||
- '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
|
||||
no-invalid-media-type-examples:
|
||||
- '#/paths/~1/post/responses/422/content/application~1json/example'
|
||||
- '#/paths/~1/post/responses/424/content/application~1json/example'
|
||||
- '#/paths/~1/post/responses/429/content/application~1json/example'
|
||||
- '#/paths/~1/post/responses/500/content/application~1json/example'
|
||||
- '#/paths/~1generate/post/responses/422/content/application~1json/example'
|
||||
- '#/paths/~1generate/post/responses/424/content/application~1json/example'
|
||||
- '#/paths/~1generate/post/responses/429/content/application~1json/example'
|
||||
- '#/paths/~1generate/post/responses/500/content/application~1json/example'
|
||||
- >-
|
||||
#/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example
|
||||
- >-
|
||||
#/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example
|
||||
- >-
|
||||
#/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example
|
||||
- >-
|
||||
#/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example
|
||||
- '#/paths/~1tokenize/post/responses/404/content/application~1json/example'
|
||||
- >-
|
||||
#/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example
|
||||
- >-
|
||||
#/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example
|
||||
- >-
|
||||
#/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example
|
||||
- >-
|
||||
#/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example
|
||||
- >-
|
||||
#/paths/~1v1~1completions/post/responses/422/content/application~1json/example
|
||||
- >-
|
||||
#/paths/~1v1~1completions/post/responses/424/content/application~1json/example
|
||||
- >-
|
||||
#/paths/~1v1~1completions/post/responses/429/content/application~1json/example
|
||||
- >-
|
||||
#/paths/~1v1~1completions/post/responses/500/content/application~1json/example
|
||||
operation-4xx-response:
|
||||
- '#/paths/~1health/get/responses'
|
||||
- '#/paths/~1info/get/responses'
|
||||
- '#/paths/~1metrics/get/responses'
|
||||
no-unused-components:
|
||||
- '#/components/schemas/Completion'
|
||||
security-defined:
|
||||
- '#/paths/~1/post'
|
||||
- '#/paths/~1generate/post'
|
||||
- '#/paths/~1generate_stream/post'
|
||||
- '#/paths/~1health/get'
|
||||
- '#/paths/~1info/get'
|
||||
- '#/paths/~1metrics/get'
|
||||
- '#/paths/~1tokenize/post'
|
||||
- '#/paths/~1v1~1chat~1completions/post'
|
||||
- '#/paths/~1v1~1completions/post'
|
File diff suppressed because it is too large
Load Diff
17
Cargo.toml
17
Cargo.toml
|
@ -1,9 +1,18 @@
|
|||
[workspace]
|
||||
members = [
|
||||
"benchmark",
|
||||
"router",
|
||||
"router/client",
|
||||
"router/grpc-metadata",
|
||||
"backends/v3",
|
||||
"backends/grpc-metadata",
|
||||
"backends/trtllm",
|
||||
"backends/client",
|
||||
"launcher"
|
||||
]
|
||||
default-members = [
|
||||
"benchmark",
|
||||
"backends/v3",
|
||||
"backends/grpc-metadata",
|
||||
# "backends/trtllm",
|
||||
"backends/client",
|
||||
"launcher"
|
||||
]
|
||||
resolver = "2"
|
||||
|
@ -18,6 +27,8 @@ homepage = "https://github.com/huggingface/text-generation-inference"
|
|||
base64 = "0.22.0"
|
||||
tokenizers = { version = "0.19.1", features = ["http"] }
|
||||
hf-hub = { version = "0.3.1", features = ["tokio"] }
|
||||
metrics = { version = "0.23.0" }
|
||||
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
||||
|
||||
[profile.release]
|
||||
incremental = true
|
||||
|
|
|
@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
|
@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
# All the tooling for CUDA
|
||||
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
||||
|
||||
WORKDIR /usr/src/tgi/backends/trtllm
|
||||
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget
|
||||
|
||||
COPY . /usr/src/tgi
|
||||
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
|
||||
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
|
||||
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl
|
||||
|
||||
# All the tooling for Rust
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
|
||||
WORKDIR /usr/src
|
||||
|
||||
# Include CUDA related libraries and tools to the Rust based image
|
||||
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
|
||||
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
|
||||
ENV PATH=/usr/local/cuda/bin:$PATH
|
||||
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH
|
||||
|
||||
RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3
|
|
@ -11,6 +11,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
|
@ -33,6 +34,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
|
@ -34,6 +35,7 @@ COPY rust-toolchain.toml rust-toolchain.toml
|
|||
COPY proto proto
|
||||
COPY benchmark benchmark
|
||||
COPY router router
|
||||
COPY backends backends
|
||||
COPY launcher launcher
|
||||
RUN cargo build --profile release-opt
|
||||
|
||||
|
|
6
Makefile
6
Makefile
|
@ -5,13 +5,13 @@ install-server-cpu:
|
|||
cd server && make install-server
|
||||
|
||||
install-router:
|
||||
cd router && cargo install --path .
|
||||
cargo install --path backends/v3/
|
||||
|
||||
install-launcher:
|
||||
cd launcher && cargo install --path .
|
||||
cargo install --path launcher/
|
||||
|
||||
install-benchmark:
|
||||
cd benchmark && cargo install --path .
|
||||
cargo install --path benchmark/
|
||||
|
||||
install: install-server install-router install-launcher
|
||||
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
cmake_minimum_required(VERSION 3.20)
|
||||
|
||||
project(tgi-trtllm-backend VERSION 1.0.0)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
|
||||
include(FetchContent)
|
||||
include(ExternalProject)
|
||||
|
||||
option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
|
||||
option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
|
||||
set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
|
||||
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
|
||||
|
||||
# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features
|
||||
find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
|
||||
|
||||
#### External dependencies ####
|
||||
include(cmake/fmt.cmake)
|
||||
include(cmake/json.cmake)
|
||||
include(cmake/spdlog.cmake)
|
||||
include(cmake/trtllm.cmake)
|
||||
|
||||
# Let's build TRTLLM as part of CMake
|
||||
add_subdirectory("${trtllm_SOURCE_DIR}/cpp" "${trtllm_SOURCE_DIR}/..")
|
||||
|
||||
# Tell CMake to need try to override the RPATH for executorWorker as it has not information on how to do so
|
||||
set_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)
|
||||
|
||||
# TGI TRTLLM Backend definition
|
||||
add_library(tgi_trtllm_backend_impl STATIC include/backend.h lib/backend.cpp include/hardware.h)
|
||||
include_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
|
||||
target_include_directories(tgi_trtllm_backend_impl PRIVATE
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
target_include_directories(tgi_trtllm_backend_impl PUBLIC "${trtllm_SOURCE_DIR}/cpp/include")
|
||||
target_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper CUDA::cudart CUDA::nvml)
|
||||
target_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt)
|
||||
|
||||
# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back
|
||||
install(TARGETS tgi_trtllm_backend_impl tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention executorWorker)
|
||||
install(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)
|
||||
|
||||
#### Unit Tests ####
|
||||
if (${TGI_TRTLLM_BACKEND_BUILD_TESTS})
|
||||
message(STATUS "Building tests")
|
||||
FetchContent_Declare(
|
||||
Catch2
|
||||
GIT_REPOSITORY https://github.com/catchorg/Catch2
|
||||
GIT_TAG v3.6.0
|
||||
)
|
||||
FetchContent_MakeAvailable(Catch2)
|
||||
|
||||
# add_executable(tgi_trtllm_backend_tests tests/infer_test.cpp)
|
||||
# target_link_libraries(tgi_trtllm_backend_tests PRIVATE tgi_trtllm_backend_impl Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog fmt::fmt CUDA::cudart CUDA::nvml)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
|
||||
include(CTest)
|
||||
include(Catch)
|
||||
# catch_discover_tests(tgi_trtllm_backend_tests)
|
||||
endif ()
|
|
@ -0,0 +1,26 @@
|
|||
[package]
|
||||
name = "text-generation-backends-trtllm"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1"
|
||||
async-stream = "0.3"
|
||||
cxx = "1.0"
|
||||
text-generation-router = { path = "../../router" }
|
||||
tokenizers = { version = "0.19", features = ["hf-hub"] }
|
||||
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.15"
|
||||
clap = { version = "4.5", features = ["derive"] }
|
||||
thiserror = "1.0.62"
|
||||
tracing = "0.1"
|
||||
tracing-opentelemetry = "0.24"
|
||||
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
|
||||
log = { version = "0.4", features = [] }
|
||||
|
||||
[build-dependencies]
|
||||
cmake = "0.1"
|
||||
cxx-build = { version = "1.0", features = ["parallel"] }
|
||||
pkg-config = "0.3"
|
|
@ -0,0 +1,100 @@
|
|||
ARG CUDA_ARCH_LIST="75-real;80-real;86-real;89-real;90-real"
|
||||
ARG OMPI_VERSION="4.1.6"
|
||||
|
||||
# Build dependencies resolver stage
|
||||
FROM lukemathwalker/cargo-chef:latest AS chef
|
||||
WORKDIR /usr/src/text-generation-inference
|
||||
|
||||
FROM chef AS planner
|
||||
COPY . .
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
# CUDA dependent dependencies resolver stage
|
||||
FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt update && apt install -y \
|
||||
build-essential \
|
||||
cmake \
|
||||
curl \
|
||||
gcc \
|
||||
g++ \
|
||||
git \
|
||||
git-lfs \
|
||||
libssl-dev \
|
||||
ninja-build \
|
||||
pkg-config \
|
||||
python3 \
|
||||
python3-setuptools \
|
||||
tar \
|
||||
wget
|
||||
|
||||
ENV TGI_INSTALL_PREFIX=/usr/local/tgi
|
||||
ENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt
|
||||
|
||||
# Install OpenMPI
|
||||
FROM cuda-builder AS mpi-builder
|
||||
ARG OMPI_VERSION
|
||||
|
||||
ENV OMPI_TARBALL_FILENAME="openmpi-$OMPI_VERSION.tar.bz2"
|
||||
RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILENAME" -P /opt/src && \
|
||||
mkdir /usr/src/mpi && \
|
||||
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
|
||||
cd /usr/src/mpi && \
|
||||
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --without-slurm && \
|
||||
make -j all && \
|
||||
make install && \
|
||||
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
|
||||
|
||||
# Install TensorRT
|
||||
FROM cuda-builder AS trt-builder
|
||||
COPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh
|
||||
RUN chmod +x /opt/install_tensorrt.sh && \
|
||||
/opt/install_tensorrt.sh
|
||||
|
||||
# Build Backend
|
||||
FROM cuda-builder AS tgi-builder
|
||||
WORKDIR /usr/src/text-generation-inference
|
||||
|
||||
# Install Rust
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \
|
||||
chmod -R a+w /root/.rustup && \
|
||||
chmod -R a+w /root/.cargo
|
||||
|
||||
ENV PATH="/root/.cargo/bin:$PATH"
|
||||
RUN cargo install cargo-chef
|
||||
|
||||
# Cache dependencies
|
||||
COPY --from=planner /usr/src/text-generation-inference/recipe.json .
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
|
||||
# Build actual TGI
|
||||
ARG CUDA_ARCH_LIST
|
||||
ENV CMAKE_PREFIX_PATH="/usr/local/mpi:/usr/local/tensorrt:$CMAKE_PREFIX_PATH"
|
||||
ENV LD_LIBRARY_PATH="/usr/local/mpi/lib:$LD_LIBRARY_PATH"
|
||||
ENV PKG_CONFIG_PATH="/usr/local/mpi/lib/pkgconfig:$PKG_CONFIG_PATH"
|
||||
|
||||
COPY . .
|
||||
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
||||
RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$TGI_INSTALL_PREFIX/lib" && \
|
||||
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release --bin text-generation-backends-trtllm
|
||||
|
||||
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
|
||||
WORKDIR /usr/local/tgi/bin
|
||||
|
||||
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
|
||||
|
||||
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
|
||||
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt
|
||||
COPY --from=tgi-builder /usr/local/tgi /usr/local/tgi
|
||||
COPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher
|
||||
|
||||
FROM runtime
|
||||
|
||||
LABEL co.huggingface.vendor="Hugging Face Inc."
|
||||
LABEL org.opencontainers.image.authors="hardware@hf.co"
|
||||
|
||||
ENTRYPOINT ["./text-generation-launcher"]
|
||||
CMD ["--executor-worker", "/usr/local/tgi/bin/executorWorker"]
|
|
@ -0,0 +1,46 @@
|
|||
# Text Generation Inference - TensorRT-LLM Backend Implementation
|
||||
|
||||
## Description
|
||||
|
||||
This folder provides the sources of the TensorRT-LLM backend implementation powered by TensorRT-LLM Executor new API
|
||||
|
||||
## Simplified Request Sequence
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
actor User
|
||||
participant TextGenerationInference.HttpServer
|
||||
participant TextGenerationInference.TensorRtLlmBackend
|
||||
participant TextGenerationInference.TensorRtLlmWorkerThread
|
||||
participant TensorRtLlm.Executor
|
||||
participant Nvidia.Gpu
|
||||
User ->> TextGenerationInference.HttpServer: POST /generate
|
||||
TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters
|
||||
TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request
|
||||
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher
|
||||
activate Nvidia.Gpu
|
||||
TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the poll for execution
|
||||
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Response with an unique request identifier
|
||||
rect rgb(10, 92, 54)
|
||||
loop every 100us
|
||||
rect rgb(15, 81, 50)
|
||||
alt Acquire lock to query executor
|
||||
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated
|
||||
else There are new generated tokens
|
||||
TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens
|
||||
TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted)
|
||||
rect rgb(11, 110, 79)
|
||||
alt Generated token is final
|
||||
TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU
|
||||
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection
|
||||
else Generated token is not final
|
||||
TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
deactivate Nvidia.Gpu
|
||||
end
|
||||
end
|
||||
|
||||
```
|
|
@ -0,0 +1,150 @@
|
|||
use cxx_build::CFG;
|
||||
use pkg_config;
|
||||
use std::env;
|
||||
use std::env::consts::ARCH;
|
||||
use std::path::{absolute, PathBuf};
|
||||
|
||||
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
|
||||
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
|
||||
const CUDA_REQUIRED_VERSION: &str = "12.5";
|
||||
const MPI_REQUIRED_VERSION: &str = "4.1";
|
||||
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
|
||||
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
|
||||
const NCCL_ROOT_DIR: Option<&str> = option_env!("NCCL_ROOT_DIR");
|
||||
|
||||
// Dependencies
|
||||
const BACKEND_DEPS: [&str; 2] = ["tgi_trtllm_backend_impl", "tgi_trtllm_backend"];
|
||||
const CUDA_TRANSITIVE_DEPS: [&str; 4] = ["cuda", "cudart", "cublas", "nvidia-ml"];
|
||||
const TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [
|
||||
("dylib", "tensorrt_llm"),
|
||||
("static", "tensorrt_llm_executor_static"),
|
||||
("dylib", "tensorrt_llm_nvrtc_wrapper"),
|
||||
("dylib", "nvinfer_plugin_tensorrt_llm"),
|
||||
("dylib", "decoder_attention"),
|
||||
];
|
||||
|
||||
macro_rules! probe {
|
||||
($name: expr, $version: expr) => {
|
||||
if let Err(_) = pkg_config::probe_library($name) {
|
||||
pkg_config::probe_library(&format!("{}-{}", $name, $version))
|
||||
.expect(&format!("Failed to locate {}", $name));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {
|
||||
// Build the backend implementation through CMake
|
||||
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
|
||||
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
|
||||
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default
|
||||
|
||||
let mut install_path = PathBuf::from(install_path);
|
||||
if !install_path.is_absolute() {
|
||||
install_path = absolute(out_dir).expect("cannot happen").join(install_path);
|
||||
}
|
||||
|
||||
let _ = cmake::Config::new(".")
|
||||
.uses_cxx11()
|
||||
.generator("Ninja")
|
||||
.profile(match is_debug {
|
||||
true => "Debug",
|
||||
false => "Release",
|
||||
})
|
||||
.env("OPT_LEVEL", opt_level)
|
||||
.define("CMAKE_INSTALL_PREFIX", &install_path)
|
||||
.define("CMAKE_CUDA_COMPILER", "/usr/local/cuda/bin/nvcc")
|
||||
.define("TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST", cuda_arch_list)
|
||||
.define("TGI_TRTLLM_BACKEND_TRT_ROOT", tensorrt_path)
|
||||
.build();
|
||||
|
||||
// Additional transitive CMake dependencies
|
||||
let deps_folder = out_dir.join("build").join("_deps");
|
||||
for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {
|
||||
let dep_name = match is_debug {
|
||||
true => format!("{}d", dependency),
|
||||
false => String::from(dependency),
|
||||
};
|
||||
let dep_path = deps_folder.join(format!("{}-build", dependency));
|
||||
println!("cargo:rustc-link-search={}", dep_path.display());
|
||||
println!("cargo:rustc-link-lib=static={}", dep_name);
|
||||
}
|
||||
|
||||
// Emit linkage information from the artifacts we just built
|
||||
let install_lib_path = install_path.join("lib");
|
||||
|
||||
println!(
|
||||
r"cargo:warning=Adding link search path: {}",
|
||||
install_lib_path.display()
|
||||
);
|
||||
println!(r"cargo:rustc-link-search={}", install_lib_path.display());
|
||||
|
||||
(PathBuf::from(install_path), deps_folder)
|
||||
}
|
||||
|
||||
fn build_ffi_layer(deps_folder: &PathBuf) {
|
||||
CFG.include_prefix = "backends/trtllm";
|
||||
cxx_build::bridge("src/lib.rs")
|
||||
.static_flag(true)
|
||||
.include(deps_folder.join("fmt-src").join("include"))
|
||||
.include(deps_folder.join("spdlog-src").join("include"))
|
||||
.include(deps_folder.join("json-src").join("include"))
|
||||
.include(deps_folder.join("trtllm-src").join("cpp").join("include"))
|
||||
.include("/usr/local/cuda/include")
|
||||
.include("/usr/local/tensorrt/include")
|
||||
.file("src/ffi.cpp")
|
||||
.std("c++20")
|
||||
.compile("tgi_trtllm_backend");
|
||||
|
||||
println!("cargo:rerun-if-changed=CMakeLists.txt");
|
||||
println!("cargo:rerun-if-changed=include/backend.h");
|
||||
println!("cargo:rerun-if-changed=lib/backend.cpp");
|
||||
println!("cargo:rerun-if-changed=include/ffi.h");
|
||||
println!("cargo:rerun-if-changed=src/ffi.cpp");
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Misc variables
|
||||
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
|
||||
let build_profile = env::var("PROFILE").unwrap();
|
||||
let (is_debug, opt_level) = match build_profile.as_ref() {
|
||||
"debug" => (true, "0"),
|
||||
_ => (false, "3"),
|
||||
};
|
||||
|
||||
// Build the backend
|
||||
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
|
||||
|
||||
// Build the FFI layer calling the backend above
|
||||
build_ffi_layer(&deps_folder);
|
||||
|
||||
// Emit linkage search path
|
||||
probe!("ompi", MPI_REQUIRED_VERSION);
|
||||
|
||||
// Probe CUDA & co. with pkg-config
|
||||
CUDA_TRANSITIVE_DEPS.iter().for_each(|name| {
|
||||
probe!(name, CUDA_REQUIRED_VERSION);
|
||||
});
|
||||
|
||||
// NCCL is slightly trickier because it might not have a pkgconfig installed
|
||||
let nccl_library_path_default = format!("/usr/local/{}-linux-gnu", ARCH);
|
||||
let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default);
|
||||
println!(r"cargo:rustc-link-search=native={}", nccl_library_path);
|
||||
println!("cargo:rustc-link-lib=dylib=nccl");
|
||||
|
||||
// TensorRT
|
||||
let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt/lib");
|
||||
println!(r"cargo:rustc-link-search=native={}", tensort_library_path);
|
||||
println!("cargo:rustc-link-lib=dylib=nvinfer");
|
||||
|
||||
// TensorRT-LLM
|
||||
TENSORRT_LLM_TRANSITIVE_DEPS
|
||||
.iter()
|
||||
.for_each(|(link_type, name)| {
|
||||
println!("cargo:rustc-link-lib={}={}", link_type, name);
|
||||
});
|
||||
|
||||
// Backend
|
||||
BACKEND_DEPS.iter().for_each(|name| {
|
||||
println!("cargo:rustc-link-lib=static={}", name);
|
||||
});
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||
GIT_TAG 11.0.1
|
||||
)
|
||||
FetchContent_MakeAvailable(fmt)
|
|
@ -0,0 +1,5 @@
|
|||
fetchcontent_declare(
|
||||
json
|
||||
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
|
||||
)
|
||||
fetchcontent_makeavailable(json)
|
|
@ -0,0 +1,17 @@
|
|||
set(SPDLOG_USE_FMT ON)
|
||||
set(SPDLOG_BUILD_SHARED OFF)
|
||||
set(SPDLOG_FMT_EXTERNAL ON)
|
||||
|
||||
# Define the level at which SPDLOG_ compilation level is defined
|
||||
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)
|
||||
else ()
|
||||
add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_INFO)
|
||||
endif ()
|
||||
|
||||
fetchcontent_declare(
|
||||
spdlog
|
||||
GIT_REPOSITORY https://github.com/gabime/spdlog.git
|
||||
GIT_TAG v1.14.1
|
||||
)
|
||||
fetchcontent_makeavailable(spdlog)
|
|
@ -0,0 +1,42 @@
|
|||
set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})
|
||||
set(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})
|
||||
|
||||
set(USE_CXX11_ABI ON)
|
||||
set(BUILD_PYT OFF)
|
||||
set(BUILD_PYBIND OFF)
|
||||
set(BUILD_MICRO_BENCHMARKS OFF)
|
||||
set(BUILD_BENCHMARKS OFF)
|
||||
set(BUILD_TESTS OFF)
|
||||
set(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})
|
||||
|
||||
message(STATUS "Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||
|
||||
if (${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||
set(FAST_BUILD ON)
|
||||
set(NVTX_DISABLE OFF)
|
||||
else ()
|
||||
set(FAST_BUILD OFF)
|
||||
set(FAST_MATH ON)
|
||||
set(NVTX_DISABLE ON)
|
||||
endif ()
|
||||
|
||||
fetchcontent_declare(
|
||||
trtllm
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
|
||||
GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1
|
||||
GIT_SHALLOW FALSE
|
||||
)
|
||||
fetchcontent_makeavailable(trtllm)
|
||||
|
||||
message(STATUS "Found TensorRT-LLM: ${trtllm_SOURCE_DIR}")
|
||||
execute_process(COMMAND git lfs install WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
|
||||
execute_process(COMMAND git lfs pull WORKING_DIRECTORY "${trtllm_SOURCE_DIR}/")
|
||||
|
||||
# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here
|
||||
set(TRTLLM_NVRTC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE INTERNAL "nvrtc wrapper library name")
|
||||
set(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}"
|
||||
CACHE INTERNAL "nvrtc wrapper library path")
|
||||
|
||||
# The same Executor Static library
|
||||
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE INTERNAL "executor_static library name")
|
||||
set(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH "${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}" CACHE INTERNAL "executor_static library path")
|
|
@ -0,0 +1,121 @@
|
|||
//
|
||||
// Created by Morgan Funtowicz on 6/30/24.
|
||||
//
|
||||
|
||||
#ifndef TGI_TRTLLM_BACKEND_H
|
||||
#define TGI_TRTLLM_BACKEND_H
|
||||
|
||||
#include <cmath>
|
||||
#include <filesystem>
|
||||
#include <span>
|
||||
#include <vector>
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <tensorrt_llm/runtime/common.h>
|
||||
#include <tensorrt_llm/executor/executor.h>
|
||||
#include <tensorrt_llm/plugins/api/tllmPlugin.h>
|
||||
|
||||
using json = nlohmann::json;
|
||||
namespace tle = tensorrt_llm::executor;
|
||||
|
||||
namespace huggingface::tgi::backends {
|
||||
using RequestId = tle::IdType;
|
||||
using TokenId = tle::TokenIdType;
|
||||
|
||||
/**
|
||||
* Initialize all the components required by TRTLLM.
|
||||
* It is required to call this function before attempting to load any engine
|
||||
*/
|
||||
void InitializeBackend();
|
||||
|
||||
/**
|
||||
*
|
||||
* @param config TensorRT-LLM configuration object
|
||||
* @param workerPath Path to the "executorWorker" provided by TensorRT-LLM when using orchestrator mode
|
||||
* @return
|
||||
*/
|
||||
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
|
||||
|
||||
/**
|
||||
* Get the sampling configuration from the parameters provided by TGI
|
||||
* @param topK
|
||||
* @param topP
|
||||
* @param temperature
|
||||
* @param repetition_penalty
|
||||
* @param frequency_penalty
|
||||
* @param seed
|
||||
* @return
|
||||
*/
|
||||
tle::SamplingConfig GetSamplingConfig(
|
||||
uint32_t topK,
|
||||
float_t topP,
|
||||
float_t temperature,
|
||||
float_t repetition_penalty,
|
||||
float_t frequency_penalty,
|
||||
uint64_t seed
|
||||
);
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
class TensorRtLlmBackend {
|
||||
private:
|
||||
const json config;
|
||||
tle::Executor executor;
|
||||
|
||||
public:
|
||||
explicit TensorRtLlmBackend(
|
||||
const std::filesystem::path &engineFolder,
|
||||
const std::filesystem::path &executorWorker
|
||||
);
|
||||
|
||||
/**
|
||||
* Indicate if the backend is ready to accept incoming request
|
||||
* @return true if ready, false otherwise
|
||||
*/
|
||||
[[nodiscard]] bool IsReady() const;
|
||||
|
||||
/**
|
||||
* Query the executor for the number of token available for pulling
|
||||
* @return
|
||||
*/
|
||||
[[nodiscard]] size_t NumResponsesReady() const;
|
||||
|
||||
/**
|
||||
* Submit a new generation task to the executor
|
||||
* @param tokens
|
||||
* @param topK
|
||||
* @param topP
|
||||
* @param temperature
|
||||
* @param repetition_penalty
|
||||
* @param frequency_penalty
|
||||
* @param seed
|
||||
* @return Request id related to this generation for reference
|
||||
*/
|
||||
[[nodiscard]] RequestId Submit(
|
||||
const std::vector<TokenId> &tokens,
|
||||
int32_t topK,
|
||||
float_t topP,
|
||||
float_t temperature,
|
||||
float_t repetition_penalty,
|
||||
float_t frequency_penalty,
|
||||
uint64_t seed
|
||||
);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param requestId The request id to poll the generation results
|
||||
* @return
|
||||
*/
|
||||
std::vector<tle::Response> Poll(RequestId requestId);
|
||||
|
||||
/**
|
||||
* Stop the underlying executor
|
||||
*/
|
||||
void Shutdown();
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
#endif //TGI_TRTLLM_BACKEND_H
|
|
@ -0,0 +1,75 @@
|
|||
//
|
||||
// Created by mfuntowicz on 7/11/24.
|
||||
//
|
||||
|
||||
#ifndef TGI_TRTLLM_BACKEND_FFI_H
|
||||
#define TGI_TRTLLM_BACKEND_FFI_H
|
||||
|
||||
#include <cstddef>
|
||||
#include "backend.h"
|
||||
|
||||
namespace huggingface::tgi::backends {
|
||||
class TensorRtLlmBackendImpl;
|
||||
}
|
||||
|
||||
#include "backends/trtllm/src/lib.rs.h"
|
||||
|
||||
|
||||
namespace huggingface::tgi::backends {
|
||||
|
||||
// struct GenerationContext;
|
||||
|
||||
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
|
||||
public:
|
||||
/***
|
||||
*
|
||||
* @param engineFolder
|
||||
* @param executorWorker
|
||||
*/
|
||||
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
|
||||
|
||||
/***
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
bool IsReady() const;
|
||||
|
||||
/***
|
||||
*
|
||||
* @param tokens
|
||||
* @param topK
|
||||
* @param topP
|
||||
* @param temperature
|
||||
* @param repetition_penalty
|
||||
* @param frequency_penalty
|
||||
* @param seed
|
||||
* @return
|
||||
*/
|
||||
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
|
||||
uint64_t
|
||||
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
|
||||
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
|
||||
|
||||
/***
|
||||
*
|
||||
* @param requestId
|
||||
* @param ctx
|
||||
* @param callback
|
||||
* @return
|
||||
*/
|
||||
size_t StreamTokens(
|
||||
const RequestId requestId,
|
||||
huggingface::tgi::backends::GenerationContext *ctx,
|
||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
||||
huggingface::tgi::backends::GenerationStep)> callback);
|
||||
};
|
||||
|
||||
/***
|
||||
*
|
||||
* @param engineFolder
|
||||
* @return
|
||||
*/
|
||||
std::unique_ptr<TensorRtLlmBackendImpl> CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker);
|
||||
}
|
||||
|
||||
#endif //TGI_TRTLLM_BACKEND_FFI_H
|
|
@ -0,0 +1,59 @@
|
|||
//
|
||||
// Created by mfuntowicz on 7/23/24.
|
||||
//
|
||||
|
||||
#ifndef TGI_TRTLLM_BACKEND_HARDWARE_H
|
||||
#define TGI_TRTLLM_BACKEND_HARDWARE_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <fmt/base.h>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <nvml.h>
|
||||
|
||||
namespace huggingface::hardware::cuda {
|
||||
|
||||
#define AMPERE_SM_MAJOR 8
|
||||
#define HOPPER_SM_MAJOR 8
|
||||
|
||||
/**
|
||||
* Store information about the version of the CUDA Compute Capabilities detected on the device
|
||||
*/
|
||||
struct CudaComputeCapabilities {
|
||||
int32_t major;
|
||||
int32_t minor;
|
||||
|
||||
[[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
|
||||
|
||||
[[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; }
|
||||
};
|
||||
|
||||
CudaComputeCapabilities GetCudaComputeCapabilities() {
|
||||
// Get the compute capabilities of the current hardware
|
||||
nvmlDevice_t device;
|
||||
CudaComputeCapabilities capabilities{0, 0};
|
||||
if (nvmlDeviceGetHandleByIndex_v2(0, &device) == NVML_SUCCESS) {
|
||||
SPDLOG_DEBUG("Successfully acquired nvmlDevice_t = 0");
|
||||
if (nvmlDeviceGetCudaComputeCapability(device, &capabilities.major, &capabilities.minor) == NVML_SUCCESS) {
|
||||
SPDLOG_INFO("Detected sm_{:d}{:d} compute capabilities", capabilities.major, capabilities.minor);
|
||||
}
|
||||
}
|
||||
|
||||
return capabilities;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the number of GPU detected. If no GPU is detected, return size_t::max()
|
||||
* @return
|
||||
*/
|
||||
std::optional<size_t> GetNumDevices() {
|
||||
uint32_t numGpus = 0;
|
||||
if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {
|
||||
return std::optional(numGpus);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif //TGI_TRTLLM_BACKEND_HARDWARE_H
|
|
@ -0,0 +1,146 @@
|
|||
#include <fstream>
|
||||
|
||||
#include <fmt/ranges.h>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include <nvml.h>
|
||||
|
||||
#include "backend.h"
|
||||
#include "hardware.h"
|
||||
|
||||
void huggingface::tgi::backends::InitializeBackend() {
|
||||
SPDLOG_INFO("Initializing Backend...");
|
||||
nvmlInit_v2();
|
||||
initTrtLlmPlugins();
|
||||
|
||||
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
|
||||
if (numGpus.has_value()) {
|
||||
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
|
||||
} else {
|
||||
SPDLOG_WARN("Failed to detected Nvidia GPU(s) on the system");
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]]
|
||||
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
|
||||
tle::ExecutorConfig execConfig(1);
|
||||
|
||||
// Retrieve the compute capabilities to enable some options at runtime
|
||||
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
|
||||
|
||||
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
|
||||
if (config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1) {
|
||||
SPDLOG_INFO("Detected single engine deployment, using leader mode");
|
||||
execConfig.setParallelConfig(tle::ParallelConfig(
|
||||
tle::CommunicationType::kMPI,
|
||||
tle::CommunicationMode::kLEADER,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
std::nullopt
|
||||
));
|
||||
} else { // Multiple engines -> using orchestrator mode (MPI involved)
|
||||
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
|
||||
execConfig.setParallelConfig(tle::ParallelConfig(
|
||||
tle::CommunicationType::kMPI,
|
||||
tle::CommunicationMode::kORCHESTRATOR,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
tle::OrchestratorConfig(true, workerPath, nullptr, true)
|
||||
));
|
||||
}
|
||||
|
||||
// Define some configuration variables
|
||||
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
|
||||
execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere());
|
||||
return execConfig;
|
||||
}
|
||||
|
||||
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
|
||||
uint32_t topK,
|
||||
float_t topP,
|
||||
float_t temperature,
|
||||
float_t repetition_penalty,
|
||||
float_t frequency_penalty,
|
||||
uint64_t seed) {
|
||||
return tle::SamplingConfig(
|
||||
1, // TGI only use a single beam
|
||||
topK,
|
||||
topP,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
std::nullopt,
|
||||
seed,
|
||||
temperature,
|
||||
temperature,
|
||||
std::nullopt,
|
||||
repetition_penalty,
|
||||
std::nullopt,
|
||||
frequency_penalty
|
||||
);
|
||||
}
|
||||
|
||||
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
|
||||
const std::filesystem::path &enginesFolder,
|
||||
const std::filesystem::path &executorWorker
|
||||
) :
|
||||
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
|
||||
executor(
|
||||
enginesFolder,
|
||||
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
|
||||
GetExecutorConfig(config, executorWorker.string()
|
||||
)) {
|
||||
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
|
||||
}
|
||||
|
||||
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
|
||||
return executor.canEnqueueRequests();
|
||||
}
|
||||
|
||||
[[nodiscard("Returned number of requests needs to be consumed")]]
|
||||
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
|
||||
return executor.getNumResponsesReady();
|
||||
}
|
||||
|
||||
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
|
||||
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
|
||||
const std::vector<tle::TokenIdType> &tokens,
|
||||
const int32_t topK,
|
||||
const float_t topP,
|
||||
const float_t temperature,
|
||||
const float_t repetition_penalty,
|
||||
const float_t frequency_penalty,
|
||||
const uint64_t seed
|
||||
) {
|
||||
#ifdef NDEBUG
|
||||
SPDLOG_DEBUG(
|
||||
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
|
||||
tokens.size(),
|
||||
executor.getLatestIterationStats().back().numActiveRequests
|
||||
);
|
||||
#else
|
||||
SPDLOG_DEBUG(
|
||||
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
|
||||
fmt::join(tokens, ", "),
|
||||
executor.getLatestIterationStats().front().numActiveRequests
|
||||
);
|
||||
#endif
|
||||
|
||||
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
|
||||
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
|
||||
|
||||
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||
const auto output = tle::OutputConfig(true, false, false, true, false);
|
||||
return executor.enqueueRequest(
|
||||
tle::Request{tokens, maxNewTokens, true, sampling, output});
|
||||
}
|
||||
|
||||
[[nodiscard("Generated tokens result must be used")]]
|
||||
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
|
||||
SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
|
||||
return executor.awaitResponses(requestId);
|
||||
}
|
||||
|
||||
|
||||
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
|
||||
SPDLOG_INFO("Shutting down executor");
|
||||
executor.shutdown();
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -ex
|
||||
|
||||
TRT_VER="10.2.0.19"
|
||||
CUDA_VER="12.5"
|
||||
CUDNN_VER="9.2.1.18-1"
|
||||
NCCL_VER="2.22.3-1+cuda12.5"
|
||||
CUBLAS_VER="12.5.3.2-1"
|
||||
NVRTC_VER="12.5.82-1"
|
||||
|
||||
for i in "$@"; do
|
||||
case $i in
|
||||
--TRT_VER=?*) TRT_VER="${i#*=}";;
|
||||
--CUDA_VER=?*) CUDA_VER="${i#*=}";;
|
||||
--CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
|
||||
--NCCL_VER=?*) NCCL_VER="${i#*=}";;
|
||||
--CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
|
||||
*) ;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
NVCC_VERSION_OUTPUT=$(nvcc --version)
|
||||
if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
|
||||
echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
install_ubuntu_requirements() {
|
||||
apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates
|
||||
ARCH=$(uname -m)
|
||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
|
||||
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
|
||||
dpkg -i cuda-keyring_1.0-1_all.deb
|
||||
|
||||
apt-get update
|
||||
if [[ $(apt list --installed | grep libcudnn9) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages libcudnn9*
|
||||
fi
|
||||
if [[ $(apt list --installed | grep libnccl) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages libnccl*
|
||||
fi
|
||||
if [[ $(apt list --installed | grep libcublas) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages libcublas*
|
||||
fi
|
||||
if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then
|
||||
apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev*
|
||||
fi
|
||||
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||
apt-get install -y --no-install-recommends libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER}
|
||||
apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER}
|
||||
apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}
|
||||
# NVRTC static library doesn't exist in NGC PyTorch container.
|
||||
NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||
apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER}
|
||||
apt-get clean
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
}
|
||||
|
||||
install_centos_requirements() {
|
||||
CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g')
|
||||
yum -y update
|
||||
yum -y install epel-release
|
||||
yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER}
|
||||
yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}
|
||||
yum clean all
|
||||
}
|
||||
|
||||
install_tensorrt() {
|
||||
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
|
||||
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
|
||||
TRT_CUDA_VERSION="12.5"
|
||||
|
||||
if [ -z "$RELEASE_URL_TRT" ];then
|
||||
ARCH=${TRT_TARGETARCH}
|
||||
if [ -z "$ARCH" ];then ARCH=$(uname -m);fi
|
||||
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
|
||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
|
||||
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
|
||||
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
|
||||
fi
|
||||
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
|
||||
tar -xf /tmp/TensorRT.tar -C /usr/local/
|
||||
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
|
||||
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
|
||||
rm -rf /tmp/TensorRT.tar
|
||||
}
|
||||
|
||||
# Install base packages depending on the base OS
|
||||
ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
|
||||
case "$ID" in
|
||||
debian)
|
||||
install_ubuntu_requirements
|
||||
install_tensorrt
|
||||
;;
|
||||
ubuntu)
|
||||
install_ubuntu_requirements
|
||||
install_tensorrt
|
||||
;;
|
||||
centos)
|
||||
install_centos_requirements
|
||||
install_tensorrt
|
||||
;;
|
||||
*)
|
||||
echo "Unable to determine OS..."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
|
@ -0,0 +1,329 @@
|
|||
use std::future::Future;
|
||||
use std::path::Path;
|
||||
use std::pin::{pin, Pin};
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use cxx::UniquePtr;
|
||||
use log::{error, warn};
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::{sleep, Instant};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tracing::{instrument, span, Level};
|
||||
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidationError::UnsupportedModality;
|
||||
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
|
||||
use text_generation_router::{FinishReason, Token};
|
||||
|
||||
use crate::errors::TensorRtLlmBackendError;
|
||||
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
|
||||
|
||||
// Value used to poll the state of the generation stream
|
||||
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
|
||||
|
||||
type InferResult<T> = Result<T, InferError>;
|
||||
|
||||
pub(crate) struct Generation {
|
||||
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
||||
done: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
/// Holds the user provided input to be executed along with a channel allowing
|
||||
/// to bubble up all the generated tokens for that tokens the to end stream.
|
||||
pub struct GenerationContext {
|
||||
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
tokens: Vec<u32>,
|
||||
done: Arc<AtomicBool>,
|
||||
queued: Instant,
|
||||
start: Option<Instant>,
|
||||
}
|
||||
|
||||
impl Stream for Generation {
|
||||
type Item = usize;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let interval = POLLING_INTERVAL_US.get_or_init(|| {
|
||||
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
|
||||
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
|
||||
});
|
||||
|
||||
if !self.done.load(Ordering::Relaxed) {
|
||||
let backend = pin!(self.executor.read());
|
||||
let status = match backend.poll(ctx) {
|
||||
Poll::Ready(executor_r) => {
|
||||
let ready = executor_r.num_responses_ready();
|
||||
if ready == 0 {
|
||||
Poll::Pending
|
||||
} else {
|
||||
Poll::Ready(Some(ready))
|
||||
}
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
};
|
||||
|
||||
let waker = ctx.waker().clone();
|
||||
tokio::spawn(async {
|
||||
sleep(Duration::from_micros(*interval)).await;
|
||||
waker.wake();
|
||||
});
|
||||
|
||||
status
|
||||
} else {
|
||||
Poll::Ready(None) // end of stream
|
||||
}
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||
(1, None)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for TensorRtLlmBackendImpl {}
|
||||
unsafe impl Sync for TensorRtLlmBackendImpl {}
|
||||
|
||||
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
|
||||
pub struct TensorRtLlmBackend {
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
|
||||
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
|
||||
// the number of available tokens (read only) in the Generation stream
|
||||
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
|
||||
}
|
||||
|
||||
impl TensorRtLlmBackend {
|
||||
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
|
||||
tokenizer: Tokenizer,
|
||||
engine_folder: P,
|
||||
executor_worker_path: PP,
|
||||
) -> Result<Self, TensorRtLlmBackendError> {
|
||||
Ok(TensorRtLlmBackend {
|
||||
tokenizer: Arc::new(tokenizer),
|
||||
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
|
||||
engine_folder.as_ref().to_str().unwrap(),
|
||||
executor_worker_path.as_ref().to_str().unwrap(),
|
||||
))),
|
||||
})
|
||||
}
|
||||
|
||||
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
|
||||
if request.top_n_tokens > 1 {
|
||||
return Err(InferError::ValidationError(
|
||||
ValidationError::TopNTokensDisabled,
|
||||
));
|
||||
}
|
||||
|
||||
// TODO: Is it really needed? How can it be validated before?
|
||||
if request.parameters.grammar.is_some() {
|
||||
return Err(InferError::ValidationError(ValidationError::Grammar));
|
||||
}
|
||||
|
||||
match request.inputs.len() {
|
||||
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
|
||||
2.. => Err(InferError::GenerationError(
|
||||
"TensorRT-LLM backend don't support multi-chunk".into(),
|
||||
)),
|
||||
1 => match request.inputs.first().expect("Single item-chunk") {
|
||||
Chunk::Text(text) => Ok(text),
|
||||
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn generate(
|
||||
&self,
|
||||
sender: UnboundedSender<InferResult<InferStreamResponse>>,
|
||||
tokens: Vec<u32>,
|
||||
top_k: u32,
|
||||
top_p: f32,
|
||||
temperature: f32,
|
||||
repetition_penalty: f32,
|
||||
frequency_penalty: f32,
|
||||
seed: u64,
|
||||
) {
|
||||
let tokenizer = Arc::clone(&self.tokenizer);
|
||||
let executor = Arc::clone(&self.backend);
|
||||
|
||||
// Let's push this in async context
|
||||
tokio::spawn(async move {
|
||||
// Define the generation state
|
||||
let mut generation = Generation {
|
||||
executor: executor.clone(),
|
||||
done: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
|
||||
// Define the context over the generation
|
||||
// TODO(asap): Do we really need so many shared-ownership?
|
||||
let ctx = Box::new(GenerationContext {
|
||||
sender: sender.clone(),
|
||||
tokenizer,
|
||||
tokens: vec![],
|
||||
done: Arc::clone(&generation.done),
|
||||
start: None,
|
||||
queued: Instant::now(),
|
||||
});
|
||||
|
||||
// We are leaking the context on-purpose to avoid the box being dropped while there are
|
||||
// still computation ongoing
|
||||
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
|
||||
let ctx_ = Box::leak(ctx);
|
||||
|
||||
// Submit the request to the batcher
|
||||
let request_id = span!(Level::DEBUG, "submit")
|
||||
.in_scope(|| async {
|
||||
let mut handle = executor.write().await;
|
||||
let request_id = handle.pin_mut().submit(
|
||||
&tokens,
|
||||
top_k as i32,
|
||||
top_p,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
seed,
|
||||
);
|
||||
|
||||
request_id
|
||||
})
|
||||
.await;
|
||||
|
||||
while let Some(_) = generation.next().await {
|
||||
let mut executor_w = executor.write().await;
|
||||
let executor = executor_w.pin_mut();
|
||||
|
||||
span!(Level::DEBUG, "decode")
|
||||
.in_scope(|| async {
|
||||
unsafe {
|
||||
executor.stream_tokens(
|
||||
request_id,
|
||||
ctx_,
|
||||
|ctx: *mut GenerationContext, step: GenerationStep| {
|
||||
let inner_ctx = &mut *ctx;
|
||||
|
||||
// Update the timestamp at which the request started effectively
|
||||
// Can be a bit off, would need to be before the callback, let's see
|
||||
inner_ctx.start.get_or_insert(Instant::now());
|
||||
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
|
||||
|
||||
// Ensure we are not running into errors
|
||||
let parcel = if !step.has_error {
|
||||
// Insert the latest generated token to the tracker
|
||||
inner_ctx.tokens.push(step.token_id);
|
||||
|
||||
// Decode the token
|
||||
let text = inner_ctx
|
||||
.tokenizer
|
||||
.decode(&[step.token_id], true)
|
||||
.expect("Failed to decode token");
|
||||
|
||||
let special = inner_ctx
|
||||
.tokenizer
|
||||
.get_added_vocabulary()
|
||||
.is_special_token(&text);
|
||||
|
||||
// Create the structure holding the token
|
||||
let token = Token {
|
||||
id: step.token_id,
|
||||
text,
|
||||
logprob: step.log_prob,
|
||||
special,
|
||||
};
|
||||
|
||||
if step.is_final {
|
||||
let generated_text = inner_ctx
|
||||
.tokenizer
|
||||
.decode(&inner_ctx.tokens, true)
|
||||
.expect("Failed to decode generated_tokens");
|
||||
|
||||
Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens: vec![],
|
||||
generated_text: GeneratedText {
|
||||
text: generated_text,
|
||||
generated_tokens: inner_ctx.tokens.len() as u32,
|
||||
finish_reason: FinishReason::EndOfSequenceToken,
|
||||
seed: None,
|
||||
},
|
||||
start: inner_ctx.start.unwrap_or(Instant::now()),
|
||||
queued: inner_ctx.queued,
|
||||
})
|
||||
} else {
|
||||
Ok(InferStreamResponse::Intermediate {
|
||||
token,
|
||||
top_tokens: vec![],
|
||||
})
|
||||
}
|
||||
} else {
|
||||
error!("Error caught while decoding: {}", &step.error_msg);
|
||||
Err(InferError::GenerationError(step.error_msg))
|
||||
};
|
||||
|
||||
// Send the parcel to the client
|
||||
inner_ctx
|
||||
.sender
|
||||
.send(parcel)
|
||||
.expect("Failed to sent msg through the channel");
|
||||
},
|
||||
);
|
||||
}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
// "Properly" free the shared context...
|
||||
// TODO: clean that piece of sh** asap
|
||||
unsafe {
|
||||
let _ = Box::from_raw(ctx_);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for TensorRtLlmBackend {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
|
||||
// Let's add a few more validation
|
||||
let input = TensorRtLlmBackend::validate(&request)?;
|
||||
|
||||
// Channel to stream the generated token as they come from the worker thread back to the transport layer
|
||||
let (sender, receiver) = unbounded_channel();
|
||||
|
||||
// Unpack parameters
|
||||
let params = &request.parameters;
|
||||
|
||||
// Preprocess the inputs to send to TRTLLM backend
|
||||
let encoding = self
|
||||
.tokenizer
|
||||
.encode(input.as_str(), true)
|
||||
.map_err(|e| InferError::GenerationError(e.to_string()))?;
|
||||
|
||||
// Generate the response
|
||||
self.generate(
|
||||
sender,
|
||||
Vec::from(encoding.get_ids()),
|
||||
params.top_k,
|
||||
params.top_p,
|
||||
params.temperature,
|
||||
params.repetition_penalty,
|
||||
params.frequency_penalty,
|
||||
params.seed,
|
||||
);
|
||||
|
||||
Ok(UnboundedReceiverStream::new(receiver))
|
||||
}
|
||||
|
||||
async fn health(&self, _current_health: bool) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
use thiserror::Error;
|
||||
|
||||
use text_generation_router::server;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TensorRtLlmBackendError {
|
||||
#[error("Tokenizer error: {0}")]
|
||||
Tokenizer(String),
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
//
|
||||
// Created by mfuntowicz on 6/30/24.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <exception>
|
||||
#include <filesystem>
|
||||
#include <limits>
|
||||
#include <iterator>
|
||||
#include <vector>
|
||||
|
||||
#include <spdlog/spdlog.h>
|
||||
#include "backends/trtllm/include/ffi.h"
|
||||
|
||||
|
||||
huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
|
||||
const std::string_view &engineFolder,
|
||||
const std::string_view &executorWorker
|
||||
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
|
||||
|
||||
|
||||
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
|
||||
return TensorRtLlmBackend::IsReady();
|
||||
}
|
||||
|
||||
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
|
||||
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
|
||||
float_t frequency_penalty, uint64_t seed) {
|
||||
|
||||
// This will copy all the items from the initial slice
|
||||
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
|
||||
return TensorRtLlmBackend::Submit(
|
||||
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
|
||||
}
|
||||
|
||||
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
|
||||
const uint64_t requestId,
|
||||
huggingface::tgi::backends::GenerationContext *ctx,
|
||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
||||
huggingface::tgi::backends::GenerationStep)> callback) {
|
||||
|
||||
size_t numTokens = 0;
|
||||
for (const auto &item: Poll(requestId)) {
|
||||
GenerationStep step;
|
||||
if (!item.hasError()) {
|
||||
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
|
||||
const auto decoded = item.getResult();
|
||||
|
||||
const auto token = decoded.outputTokenIds[0][0];
|
||||
const auto isFinal = decoded.isFinal;
|
||||
const auto logProb = decoded.logProbs.value()[0][0];
|
||||
|
||||
++numTokens;
|
||||
|
||||
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
|
||||
step = huggingface::tgi::backends::GenerationStep{
|
||||
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
|
||||
};
|
||||
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
|
||||
} else {
|
||||
// TODO : Return rest::Result with error
|
||||
const auto what = item.getErrorMsg();
|
||||
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
|
||||
step = huggingface::tgi::backends::GenerationStep{
|
||||
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
|
||||
};
|
||||
}
|
||||
|
||||
callback(std::move(ctx), std::move(step));
|
||||
}
|
||||
|
||||
return numTokens;
|
||||
}
|
||||
|
||||
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
|
||||
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
|
||||
// Unconditionally call this to initialize and discover TRTLLM plugins
|
||||
InitializeBackend();
|
||||
|
||||
const auto enginePath = std::string_view(engineFolder.begin(), engineFolder.end());
|
||||
const auto executorPath = std::string_view(executorWorker.begin(), executorWorker.end());
|
||||
return std::make_unique<TensorRtLlmBackendImpl>(std::move(enginePath), std::move(executorPath));
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
pub use backend::{GenerationContext, TensorRtLlmBackend};
|
||||
|
||||
mod backend;
|
||||
pub mod errors;
|
||||
|
||||
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
|
||||
mod ffi {
|
||||
|
||||
/// Struct used as shared type between rust and C++ to represent the result
|
||||
/// of a single decoding iteration
|
||||
pub struct GenerationStep {
|
||||
token_id: u32,
|
||||
log_prob: f32,
|
||||
is_final: bool,
|
||||
has_error: bool,
|
||||
error_msg: String,
|
||||
}
|
||||
|
||||
extern "Rust" {
|
||||
type GenerationContext;
|
||||
}
|
||||
|
||||
unsafe extern "C++" {
|
||||
include!("backends/trtllm/src/ffi.cpp");
|
||||
|
||||
/// Represent an instance of the underlying TensorRT-LLM backend
|
||||
type TensorRtLlmBackendImpl;
|
||||
|
||||
/// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `engine_folder`: Path to the folder containing all the TRTLLM engines
|
||||
/// * `executor_worker`: Path to the TRTLLM executor worker
|
||||
///
|
||||
/// returns: <unknown>
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
/// ```
|
||||
#[rust_name = "create_tensorrt_llm_backend"]
|
||||
fn CreateTensorRtLlmBackend(
|
||||
engine_folder: &str,
|
||||
executor_worker: &str,
|
||||
) -> UniquePtr<TensorRtLlmBackendImpl>;
|
||||
|
||||
// #[rust_name = "is_ready"]
|
||||
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
|
||||
|
||||
#[rust_name = "num_responses_ready"]
|
||||
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
|
||||
|
||||
#[rust_name = "submit"]
|
||||
fn Submit(
|
||||
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||
tokens: &[u32],
|
||||
top_k: i32,
|
||||
top_p: f32,
|
||||
temperature: f32,
|
||||
repetition_penalty: f32,
|
||||
frequency_penalty: f32,
|
||||
seed: u64,
|
||||
) -> u64;
|
||||
|
||||
#[rust_name = "stream_tokens"]
|
||||
unsafe fn StreamTokens(
|
||||
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||
request_id: u64,
|
||||
ctx: *mut GenerationContext,
|
||||
cb: unsafe fn(*mut GenerationContext, GenerationStep),
|
||||
) -> usize;
|
||||
|
||||
// #[rust_name = "shutdown"]
|
||||
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,166 @@
|
|||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use clap::Parser;
|
||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||
|
||||
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
|
||||
use text_generation_backends_trtllm::TensorRtLlmBackend;
|
||||
use text_generation_router::server;
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
max_best_of: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_tokens: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(long, env, required = true)]
|
||||
tokenizer_name: String,
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(long, env)]
|
||||
model_id: String,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||
otlp_service_name: String,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
messages_api_enabled: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
#[clap(long, env)]
|
||||
auth_token: Option<String>,
|
||||
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
|
||||
executor_worker: PathBuf,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), TensorRtLlmBackendError> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
hostname,
|
||||
port,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
model_id,
|
||||
validation_workers,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
otlp_service_name,
|
||||
cors_allow_origin,
|
||||
messages_api_enabled,
|
||||
max_client_batch_size,
|
||||
auth_token,
|
||||
executor_worker,
|
||||
} = args;
|
||||
|
||||
// Launch Tokio runtime
|
||||
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||
|
||||
// Validate args
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
}
|
||||
|
||||
if !executor_worker.exists() {
|
||||
return Err(TensorRtLlmBackendError::ArgumentValidation(format!(
|
||||
"`executor_work` specified path doesn't exists: {}",
|
||||
executor_worker.display()
|
||||
)));
|
||||
}
|
||||
|
||||
// Run server
|
||||
let tokenizer = Tokenizer::from_pretrained(
|
||||
tokenizer_name.clone(),
|
||||
Some(FromPretrainedParameters {
|
||||
revision: revision.clone().unwrap_or(String::from("main")),
|
||||
user_agent: HashMap::new(),
|
||||
auth_token,
|
||||
}),
|
||||
)
|
||||
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
|
||||
|
||||
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
|
||||
server::run(
|
||||
backend,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
validation_workers,
|
||||
None,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
hostname,
|
||||
port,
|
||||
cors_allow_origin,
|
||||
false,
|
||||
None,
|
||||
None,
|
||||
messages_api_enabled,
|
||||
true,
|
||||
max_client_batch_size,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
//
|
||||
// Created by mfuntowicz on 7/2/24.
|
||||
//
|
||||
#include <catch2/catch_all.hpp>
|
||||
#include <spdlog/spdlog.h>
|
||||
#include "../include/backend.h"
|
||||
|
||||
TEST_CASE("Load TRTLLM Engine on the TGI Backend", "[trtllm][engine][load]") {
|
||||
const auto engines = std::filesystem::path("/home/mfuntowicz/.cache/huggingface/assets/trtllm/0.11.0.dev2024062500/meta-llama--Meta-Llama-3-8B-Instruct/4090/engines/");
|
||||
const auto executor = std::filesystem::path("/home/mfuntowicz/Workspace/text-generation-inference/backends/trtllm/cmake-build-debug/cmake-build-debug/_deps/trtllm-src/cpp/tensorrt_llm/executor_worker/executorWorker");
|
||||
|
||||
spdlog::info("Loading config from: {}", absolute(engines).string());
|
||||
huggingface::tgi::backends::TensorRtLlmBackend backend(engines, executor);
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
[package]
|
||||
name = "text-generation-router-v3"
|
||||
description = "Text Generation Webserver"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "text-generation-router"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.16"
|
||||
text-generation-router = { path = "../../router" }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
grpc-metadata = { path = "../grpc-metadata" }
|
||||
futures = "0.3.28"
|
||||
hf-hub = { workspace = true }
|
||||
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||
metrics = { workspace = true }
|
||||
metrics-exporter-prometheus = { workspace = true }
|
||||
nohash-hasher = "0.2.0"
|
||||
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.13.0"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.11.20", features = [] }
|
||||
serde = "1.0.188"
|
||||
serde_json = "1.0.107"
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true}
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.14"
|
||||
tower-http = { version = "0.5.1", features = ["cors"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.21.0"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
|
||||
utoipa = { version = "4.2.0", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "6.0.0", features = ["axum"] }
|
||||
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
|
||||
minijinja = { version = "2.0.2" }
|
||||
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }
|
||||
futures-util = "0.3.30"
|
||||
regex = "1.10.3"
|
||||
once_cell = "1.19.0"
|
||||
image = "0.25.1"
|
||||
base64 = { workspace = true }
|
||||
prost = "^0.12"
|
||||
tonic = "^0.10"
|
||||
tower = "^0.4"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.10.1"
|
||||
prost-build = "0.12.1"
|
||||
|
||||
[features]
|
||||
default = ["ngrok"]
|
||||
ngrok = ["text-generation-router/ngrok"]
|
||||
google = ["text-generation-router/google"]
|
||||
kserve = ["text-generation-router/kserve"]
|
|
@ -0,0 +1,19 @@
|
|||
use std::fs;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("cargo:rerun-if-changed=../../proto/");
|
||||
|
||||
fs::create_dir_all("src/client/pb").unwrap_or(());
|
||||
let mut config = prost_build::Config::new();
|
||||
config.protoc_arg("--experimental_allow_proto3_optional");
|
||||
|
||||
tonic_build::configure()
|
||||
.build_client(true)
|
||||
.build_server(false)
|
||||
.out_dir("src/client/pb")
|
||||
.include_file("mod.rs")
|
||||
.compile_with_config(config, &["../../proto/v3/generate.proto"], &["../../proto"])
|
||||
.unwrap_or_else(|e| panic!("protobuf compilation failed: {e}"));
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,501 @@
|
|||
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
|
||||
/// Batching and inference logic
|
||||
use crate::queue::{Entry, Queue};
|
||||
use async_trait::async_trait;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
use text_generation_router::{FinishReason, PrefillToken, Token};
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub struct BackendV3 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
/// Client clone, used for health checks to skip the queue
|
||||
client: ShardedClient,
|
||||
}
|
||||
|
||||
impl BackendV3 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
) -> Self {
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
16,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client.clone(),
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
queue.clone(),
|
||||
batching_task_notifier.clone(),
|
||||
));
|
||||
|
||||
Self {
|
||||
queue,
|
||||
batching_task_notifier,
|
||||
client,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Backend for BackendV3 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
request,
|
||||
response_tx,
|
||||
span: Span::current(),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
block_allocation: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
// to be batched
|
||||
self.batching_task_notifier.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok(UnboundedReceiverStream::new(response_rx))
|
||||
}
|
||||
|
||||
async fn health(&self, current_health: bool) -> bool {
|
||||
if current_health {
|
||||
// Generation is healthy, we only check that the shards can allocate on device
|
||||
self.client.device_health().await
|
||||
} else {
|
||||
self.client.model_health().await
|
||||
}
|
||||
.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
queue: Queue,
|
||||
notifier: Arc<Notify>,
|
||||
) {
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
notifier.notified().await;
|
||||
|
||||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue
|
||||
.next_batch(
|
||||
None,
|
||||
max_batch_size,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
// all requests have met their stopping criteria)
|
||||
while let Some(batch) = cached_batch {
|
||||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
.await
|
||||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||
.increment(1);
|
||||
} else {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||
// Add relationships
|
||||
span.follows_from(&entry_waiting_span);
|
||||
entry_waiting_span.follows_from(&span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_waiting_span);
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
entries.extend(new_entries);
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_size = entries.len();
|
||||
let next_batch_span =
|
||||
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = decode(&mut client, batches, &mut entries)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<CachedBatch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||
.record(concat_duration.as_secs_f64());
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
for id in batch_ids {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
#[instrument(skip_all)]
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
next_batch: Option<CachedBatch>,
|
||||
entries: &IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let mut batch = next_batch?;
|
||||
|
||||
// No need to filter
|
||||
if batch.size as usize == entries.len() {
|
||||
return Some(batch);
|
||||
}
|
||||
|
||||
let id = batch.id;
|
||||
|
||||
// Retain only requests that are still in entries
|
||||
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||
|
||||
if batch.request_ids.is_empty() {
|
||||
// All requests have been filtered out
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.clear_cache(Some(id)).await.unwrap();
|
||||
None
|
||||
} else {
|
||||
// Filter Python shard cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
/// and filter entries
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
generations.into_iter().for_each(|generation| {
|
||||
let id = generation.request_id;
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let entry = entries
|
||||
.get(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||
// Send generation responses back to the infer task
|
||||
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
err
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Send responses through the `entry` response channel
|
||||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
let mut stopped = false;
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Create Token objects
|
||||
// We do that here instead of in the Python code as Rust for loops are faster
|
||||
let prefill_tokens = prefill_tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(prefill_tokens.logprobs)
|
||||
.zip(prefill_tokens.texts)
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(tokens_.logprobs)
|
||||
.zip(tokens_.texts)
|
||||
.zip(tokens_.is_special)
|
||||
.enumerate()
|
||||
.peekable();
|
||||
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||
let token = Token {
|
||||
id,
|
||||
text,
|
||||
logprob,
|
||||
special,
|
||||
};
|
||||
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||
top_tokens_
|
||||
.ids
|
||||
.iter()
|
||||
.zip(top_tokens_.logprobs.iter())
|
||||
.zip(top_tokens_.texts.iter())
|
||||
.zip(top_tokens_.is_special.iter())
|
||||
.map(|(((&id, &logprob), text), &special)| Token {
|
||||
id,
|
||||
text: text.to_string(),
|
||||
logprob,
|
||||
special,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
match (&generation.generated_text, iterator.peek()) {
|
||||
(Some(generated_text), None) => {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: GeneratedText::from(generated_text.clone()),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
}
|
||||
_ => {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stopped)
|
||||
}
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
impl From<crate::client::GeneratedText> for GeneratedText {
|
||||
fn from(value: crate::client::GeneratedText) -> Self {
|
||||
let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v3_finish_reason {
|
||||
crate::client::FinishReason::Length => FinishReason::Length,
|
||||
crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
crate::client::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
};
|
||||
|
||||
Self {
|
||||
text: value.text,
|
||||
generated_tokens: value.generated_tokens,
|
||||
finish_reason,
|
||||
seed: value.seed,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,284 @@
|
|||
/// Single shard Client
|
||||
use crate::client::{pb, Chunk};
|
||||
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::Engine;
|
||||
use grpc_metadata::InjectTelemetryContext;
|
||||
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v3::*;
|
||||
use std::cmp::min;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
|
||||
/// Text Generation Inference gRPC client
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Client {
|
||||
stub: TextGenerationServiceClient<Channel>,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
/// Returns a client connected to the given url
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let channel = Channel::builder(uri).connect().await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let channel = Channel::from_shared("http://[::]:50051".to_string())
|
||||
.unwrap()
|
||||
.connect_with_connector(tower::service_fn(move |_: Uri| {
|
||||
tokio::net::UnixStream::connect(path.clone())
|
||||
}))
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
stub: TextGenerationServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a list of uris or unix sockets of all shards
|
||||
#[instrument(skip(self))]
|
||||
pub async fn service_discovery(&mut self) -> Result<Vec<String>> {
|
||||
let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();
|
||||
let response = self.stub.service_discovery(request).await.map_err(|_| {
|
||||
ClientError::Connection("Server does not support v3 interface".to_string())
|
||||
})?;
|
||||
let urls = response
|
||||
.into_inner()
|
||||
.urls
|
||||
.into_iter()
|
||||
// Remove unix socket prefix
|
||||
.map(|url| match url.strip_prefix("unix://") {
|
||||
None => url,
|
||||
Some(stripped_url) => stripped_url.to_string(),
|
||||
})
|
||||
.collect();
|
||||
Ok(urls)
|
||||
}
|
||||
|
||||
/// Get model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<InfoResponse> {
|
||||
let request = tonic::Request::new(InfoRequest {}).inject_context();
|
||||
let response = self.stub.info(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Get model health
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let request = tonic::Request::new(HealthRequest {}).inject_context();
|
||||
let response = self.stub.health(request).await?.into_inner();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();
|
||||
self.stub.clear_cache(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let request = tonic::Request::new(FilterBatchRequest {
|
||||
batch_id,
|
||||
request_ids,
|
||||
})
|
||||
.inject_context();
|
||||
let filtered_batch = self.stub.filter_batch(request).await?.into_inner();
|
||||
Ok(filtered_batch.batch)
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum amount of tokens supported by the hardware
|
||||
#[instrument(skip_all)]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let mut n_tokens = 0;
|
||||
let mut requests = Vec::new();
|
||||
// Create requests
|
||||
while n_tokens < max_prefill_tokens {
|
||||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut input_chunks = Vec::new();
|
||||
input_chunks
|
||||
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
|
||||
if n_tokens == 0 {
|
||||
input_chunks.push(
|
||||
Chunk::Image(Image {
|
||||
// Safe unwrap, because we control the data.
|
||||
data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),
|
||||
mimetype: "image/jpeg;base64".to_string(),
|
||||
})
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// Send stringly-typed inputs for compatibility for backends that haven't
|
||||
// been updated to support chunks.
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if n_tokens == 0 {
|
||||
// 1 request is enough to test vision heads.
|
||||
// Sending images on other queries messes up easily with truncation.
|
||||
inputs.push_str(&format!(
|
||||
"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})",
|
||||
));
|
||||
}
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
inputs,
|
||||
input_chunks: Some(Input {
|
||||
chunks: input_chunks,
|
||||
}),
|
||||
// We truncate the input on the server side to be sure that it has the correct size
|
||||
truncate,
|
||||
// Blocks and slots will be set on the server side if we use paged attention
|
||||
blocks: vec![],
|
||||
slots: vec![],
|
||||
// Set sampling parameters to also take these ops into account in the max memory
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 0.9,
|
||||
top_k: 10,
|
||||
top_p: 0.9,
|
||||
typical_p: 0.9,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.2,
|
||||
frequency_penalty: 0.1,
|
||||
watermark: true,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: max_total_tokens - truncate,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: true,
|
||||
}),
|
||||
prefill_logprobs: true,
|
||||
top_n_tokens: 20,
|
||||
adapter_id: None,
|
||||
});
|
||||
n_tokens += max_input_length;
|
||||
|
||||
// Check max_batch_size
|
||||
if Some(requests.len()) == max_batch_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let batch = Batch {
|
||||
id: 0,
|
||||
size: requests.len() as u32,
|
||||
requests,
|
||||
max_tokens: max_input_length,
|
||||
max_blocks: 0,
|
||||
};
|
||||
|
||||
let request = tonic::Request::new(WarmupRequest {
|
||||
batch: Some(batch),
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
})
|
||||
.inject_context();
|
||||
let response = self.stub.warmup(request).await?.into_inner();
|
||||
Ok(response.max_supported_total_tokens)
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
|
||||
let response = self.stub.prefill(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
|
||||
))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
|
||||
let response = self.stub.decode(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
response.batch,
|
||||
DecodeTimings::new(
|
||||
response.concat_ns,
|
||||
response.forward_ns,
|
||||
response.decode_ns,
|
||||
response.total_ns,
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PrefillTimings {
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl PrefillTimings {
|
||||
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DecodeTimings {
|
||||
pub concat: Option<Duration>,
|
||||
pub forward: Duration,
|
||||
pub decode: Duration,
|
||||
pub total: Duration,
|
||||
}
|
||||
|
||||
impl DecodeTimings {
|
||||
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
|
||||
Self {
|
||||
concat: concat_ns.map(Duration::from_nanos),
|
||||
forward: Duration::from_nanos(forward_ns),
|
||||
decode: Duration::from_nanos(decode_ns),
|
||||
total: Duration::from_nanos(total_ns),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
//! Text Generation gRPC client library
|
||||
|
||||
use async_trait::async_trait;
|
||||
use thiserror::Error;
|
||||
use tonic::transport;
|
||||
use tonic::Status;
|
||||
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
mod pb;
|
||||
|
||||
mod grpc_client;
|
||||
mod sharded_client;
|
||||
|
||||
pub use grpc_client::Client;
|
||||
pub use pb::generate::v3::{
|
||||
input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
|
||||
HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,
|
||||
StoppingCriteriaParameters,
|
||||
};
|
||||
pub use sharded_client::ShardedClient;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Health {
|
||||
/// Check if a generate server is healthy by asking it to allocate a tensor on device
|
||||
async fn device_health(&self) -> Result<()>;
|
||||
|
||||
/// Check if a generate server is healthy by doing a forward pass.
|
||||
/// EXPENSIVE
|
||||
async fn model_health(&self) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ShardInfo {
|
||||
pub requires_padding: bool,
|
||||
pub dtype: String,
|
||||
pub device_type: String,
|
||||
pub window_size: Option<u32>,
|
||||
pub speculate: u32,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
pub enum ClientError {
|
||||
#[error("Could not connect to Text Generation server: {0}")]
|
||||
Connection(String),
|
||||
#[error("Server error: {0}")]
|
||||
Generation(String),
|
||||
#[error("Sharded results are empty")]
|
||||
EmptyResults,
|
||||
}
|
||||
|
||||
impl From<Status> for ClientError {
|
||||
fn from(err: Status) -> Self {
|
||||
let err = Self::Generation(err.message().to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
impl From<transport::Error> for ClientError {
|
||||
fn from(err: transport::Error) -> Self {
|
||||
let err = Self::Connection(err.to_string());
|
||||
tracing::error!("{err}");
|
||||
err
|
||||
}
|
||||
}
|
||||
|
||||
// Small convenience re-wrapping of `Chunk`.
|
||||
impl From<Chunk> for InputChunk {
|
||||
fn from(chunk: Chunk) -> Self {
|
||||
InputChunk { chunk: Some(chunk) }
|
||||
}
|
||||
}
|
||||
|
||||
static WARMUP_IMAGE_BASE64 :&str = "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=";
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ClientError>;
|
|
@ -0,0 +1,260 @@
|
|||
use crate::client::{ClientError, Result};
|
||||
/// Multi shard Client
|
||||
use crate::client::{Health, ShardInfo};
|
||||
|
||||
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
|
||||
use crate::client::{
|
||||
Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,
|
||||
NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use crate::client::{Chunk, InfoResponse, Input};
|
||||
use async_trait::async_trait;
|
||||
use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Text Generation Inference gRPC multi client
|
||||
pub struct ShardedClient {
|
||||
clients: Vec<Client>,
|
||||
}
|
||||
|
||||
impl ShardedClient {
|
||||
fn new(clients: Vec<Client>) -> Self {
|
||||
Self { clients }
|
||||
}
|
||||
|
||||
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||
/// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.
|
||||
async fn from_master_client(mut master_client: Client) -> Result<Self> {
|
||||
// Get all uris/unix sockets from the master client
|
||||
let uris = master_client.service_discovery().await?;
|
||||
let futures = uris.into_iter().map(Client::connect_uds);
|
||||
let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();
|
||||
Ok(Self::new(clients?))
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given uri
|
||||
#[allow(dead_code)]
|
||||
pub async fn connect(uri: Uri) -> Result<Self> {
|
||||
let master_client = Client::connect(uri).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Returns a client connected to the given unix socket
|
||||
pub async fn connect_uds(path: String) -> Result<Self> {
|
||||
let master_client = Client::connect_uds(path).await?;
|
||||
Self::from_master_client(master_client).await
|
||||
}
|
||||
|
||||
/// Get the model info
|
||||
#[instrument(skip(self))]
|
||||
pub async fn info(&mut self) -> Result<ShardInfo> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.info())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
|
||||
}
|
||||
|
||||
/// GRPC health check
|
||||
#[instrument(skip(self))]
|
||||
pub async fn health(&mut self) -> Result<HealthResponse> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.health())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Clear the past generations cache
|
||||
#[instrument(skip(self))]
|
||||
pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| client.clear_cache(batch_id))
|
||||
.collect();
|
||||
join_all(futures).await.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Filter a cached batch
|
||||
#[instrument(skip(self))]
|
||||
pub async fn filter_batch(
|
||||
&mut self,
|
||||
batch_id: u64,
|
||||
request_ids: Vec<u64>,
|
||||
) -> Result<Option<CachedBatch>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))
|
||||
.collect();
|
||||
// all shards return the same message
|
||||
join_all(futures).await.pop().unwrap()
|
||||
}
|
||||
|
||||
/// Warmup on a max size batch
|
||||
///
|
||||
/// Returns the maximum amount of tokens supported by the hardware
|
||||
#[instrument(skip(self))]
|
||||
pub async fn warmup(
|
||||
&mut self,
|
||||
max_input_length: u32,
|
||||
max_prefill_tokens: u32,
|
||||
max_total_tokens: u32,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<Option<u32>> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| {
|
||||
Box::pin(client.warmup(
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
max_batch_size,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
// Take the minimum value
|
||||
let results = join_all(futures)
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<Option<u32>>>>()?;
|
||||
Ok(results.into_iter().flatten().min())
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given batch
|
||||
///
|
||||
/// Returns Generation for each request in batch
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
|
||||
pub async fn prefill(
|
||||
&mut self,
|
||||
batch: Batch,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.prefill(batch.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
|
||||
/// Generate one token for each request in the given cached batches
|
||||
///
|
||||
/// Returns Generation for each request in batches
|
||||
/// and the next cached batch
|
||||
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
|
||||
pub async fn decode(
|
||||
&mut self,
|
||||
batches: Vec<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
|
||||
let futures: Vec<_> = self
|
||||
.clients
|
||||
.iter_mut()
|
||||
.map(|client| Box::pin(client.decode(batches.clone())))
|
||||
.collect();
|
||||
#[allow(clippy::type_complexity)]
|
||||
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
|
||||
join_all(futures).await.into_iter().collect();
|
||||
let mut results = results?;
|
||||
|
||||
let (mut generations, next_batch, mut timings) =
|
||||
results.pop().ok_or(ClientError::EmptyResults)?;
|
||||
|
||||
// Merge generations from different model shards
|
||||
for (mut shard_generations, _, shard_timings) in results.into_iter() {
|
||||
generations.append(&mut shard_generations);
|
||||
// Return the timings of the slowest shard
|
||||
if shard_timings.total > timings.total {
|
||||
timings = shard_timings;
|
||||
}
|
||||
}
|
||||
Ok((generations, next_batch, timings))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InfoResponse> for ShardInfo {
|
||||
fn from(value: InfoResponse) -> Self {
|
||||
Self {
|
||||
requires_padding: value.requires_padding,
|
||||
dtype: value.dtype,
|
||||
device_type: value.device_type,
|
||||
window_size: value.window_size,
|
||||
speculate: value.speculate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Health for ShardedClient {
|
||||
async fn device_health(&self) -> Result<()> {
|
||||
self.clone().health().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn model_health(&self) -> Result<()> {
|
||||
// Dummy batch of 1 token and 1 generated token
|
||||
let liveness_request = Request {
|
||||
id: u64::MAX,
|
||||
inputs: "liveness".to_string(),
|
||||
input_chunks: Some(Input {
|
||||
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||
}),
|
||||
truncate: 10,
|
||||
prefill_logprobs: false,
|
||||
parameters: Some(NextTokenChooserParameters {
|
||||
temperature: 1.0,
|
||||
top_k: 0,
|
||||
top_p: 1.0,
|
||||
typical_p: 1.0,
|
||||
do_sample: false,
|
||||
seed: 0,
|
||||
repetition_penalty: 1.0,
|
||||
frequency_penalty: 0.0,
|
||||
watermark: false,
|
||||
grammar: String::new(),
|
||||
grammar_type: GrammarType::None as i32,
|
||||
}),
|
||||
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||
max_new_tokens: 1,
|
||||
stop_sequences: vec![],
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
adapter_id: None,
|
||||
};
|
||||
let batch = Batch {
|
||||
id: u64::MAX,
|
||||
requests: vec![liveness_request],
|
||||
size: 1,
|
||||
max_tokens: 2,
|
||||
max_blocks: 1,
|
||||
};
|
||||
self.clone().prefill(batch).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,142 @@
|
|||
mod backend;
|
||||
mod block_allocator;
|
||||
mod client;
|
||||
mod queue;
|
||||
|
||||
use crate::client::{ClientError, ShardedClient};
|
||||
pub(crate) use backend::BackendV3;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, ToSchema)]
|
||||
pub struct BackendInfo {
|
||||
/// Mandatory
|
||||
#[schema(example = "cuda")]
|
||||
pub model_device_type: String,
|
||||
#[schema(example = "torch.float16")]
|
||||
pub model_dtype: String,
|
||||
|
||||
/// Backend parameters
|
||||
#[schema(example = "1")]
|
||||
pub speculate: usize,
|
||||
#[schema(example = "1.2")]
|
||||
pub waiting_served_ratio: f32,
|
||||
#[schema(example = "32000")]
|
||||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn connect_backend(
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
master_shard_uds_path: String,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
) -> Result<(BackendV3, BackendInfo), V3Error> {
|
||||
// Helper function
|
||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||
match max_supported_batch_total_tokens {
|
||||
// Older models do not support automatic max-batch-total-tokens
|
||||
None => {
|
||||
let max_batch_total_tokens = max_batch_total_tokens
|
||||
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
|
||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||
Ok(max_batch_total_tokens)
|
||||
}
|
||||
// Flash attention models return their max supported total tokens
|
||||
Some(max_supported_batch_total_tokens) => {
|
||||
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||
if max_batch_total_tokens.is_some() {
|
||||
tracing::warn!(
|
||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||
Attention models."
|
||||
);
|
||||
tracing::warn!(
|
||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||
);
|
||||
}
|
||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||
return Err(V3Error::NotEnoughMemory(max_total_tokens));
|
||||
}
|
||||
|
||||
Ok(max_supported_batch_total_tokens)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
.map_err(V3Error::Connection)?;
|
||||
|
||||
// server is running on v3
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.map_err(V3Error::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||
sharded_client
|
||||
.warmup(
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(V3Error::Warmup)?,
|
||||
)?;
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
|
||||
let backend_info = BackendInfo {
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
model_device_type: shard_info.device_type.clone(),
|
||||
model_dtype: shard_info.dtype.clone(),
|
||||
speculate: shard_info.speculate as usize,
|
||||
};
|
||||
|
||||
let backend = BackendV3::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
);
|
||||
|
||||
tracing::info!("Using backend V3");
|
||||
|
||||
Ok((backend, backend_info))
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum V3Error {
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
Connection(ClientError),
|
||||
#[error("Unable to get the Python model shards info: {0}")]
|
||||
Info(ClientError),
|
||||
#[error("Unable to warmup the Python model shards: {0}")]
|
||||
Warmup(ClientError),
|
||||
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||
NotEnoughMemory(usize),
|
||||
}
|
|
@ -0,0 +1,208 @@
|
|||
use clap::{Parser, Subcommand};
|
||||
use text_generation_router::server;
|
||||
use text_generation_router_v3::{connect_backend, V3Error};
|
||||
use thiserror::Error;
|
||||
|
||||
/// App Configuration
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
#[command(subcommand)]
|
||||
command: Option<Commands>,
|
||||
|
||||
#[clap(default_value = "128", long, env)]
|
||||
max_concurrent_requests: usize,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
max_best_of: usize,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_stop_sequences: usize,
|
||||
#[clap(default_value = "5", long, env)]
|
||||
max_top_n_tokens: u32,
|
||||
#[clap(default_value = "1024", long, env)]
|
||||
max_input_tokens: usize,
|
||||
#[clap(default_value = "2048", long, env)]
|
||||
max_total_tokens: usize,
|
||||
#[clap(default_value = "1.2", long, env)]
|
||||
waiting_served_ratio: f32,
|
||||
#[clap(default_value = "4096", long, env)]
|
||||
max_batch_prefill_tokens: u32,
|
||||
#[clap(long, env)]
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
#[clap(default_value = "20", long, env)]
|
||||
max_waiting_tokens: usize,
|
||||
#[clap(long, env)]
|
||||
max_batch_size: Option<usize>,
|
||||
#[clap(default_value = "0.0.0.0", long, env)]
|
||||
hostname: String,
|
||||
#[clap(default_value = "3000", long, short, env)]
|
||||
port: u16,
|
||||
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||
master_shard_uds_path: String,
|
||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(long, env)]
|
||||
tokenizer_config_path: Option<String>,
|
||||
#[clap(long, env)]
|
||||
revision: Option<String>,
|
||||
#[clap(default_value = "2", long, env)]
|
||||
validation_workers: usize,
|
||||
#[clap(long, env)]
|
||||
api_key: Option<String>,
|
||||
#[clap(long, env)]
|
||||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(default_value = "text-generation-inference.router", long, env)]
|
||||
otlp_service_name: String,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env)]
|
||||
ngrok: bool,
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_edge: Option<String>,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
messages_api_enabled: bool,
|
||||
#[clap(long, env, default_value_t = false)]
|
||||
disable_grammar_support: bool,
|
||||
#[clap(default_value = "4", long, env)]
|
||||
max_client_batch_size: usize,
|
||||
#[clap(long, env, default_value_t)]
|
||||
disable_usage_stats: bool,
|
||||
#[clap(long, env, default_value_t)]
|
||||
disable_crash_reports: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
enum Commands {
|
||||
PrintSchema,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), RouterError> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
let Args {
|
||||
command,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
hostname,
|
||||
port,
|
||||
master_shard_uds_path,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
validation_workers,
|
||||
api_key,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
otlp_service_name,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
disable_usage_stats,
|
||||
disable_crash_reports,
|
||||
max_client_batch_size,
|
||||
} = args;
|
||||
|
||||
if let Some(Commands::PrintSchema) = command {
|
||||
use utoipa::OpenApi;
|
||||
let api_doc = text_generation_router::server::ApiDoc::openapi();
|
||||
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
|
||||
println!("{}", api_doc);
|
||||
std::process::exit(0);
|
||||
};
|
||||
text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
|
||||
|
||||
// Validate args
|
||||
if max_input_tokens >= max_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
|
||||
));
|
||||
}
|
||||
if max_input_tokens as u32 > max_batch_prefill_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
|
||||
}
|
||||
|
||||
if validation_workers == 0 {
|
||||
return Err(RouterError::ArgumentValidation(
|
||||
"`validation_workers` must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
|
||||
if max_batch_prefill_tokens > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
if max_total_tokens as u32 > *max_batch_total_tokens {
|
||||
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
|
||||
}
|
||||
}
|
||||
|
||||
let (backend, _backend_info) = connect_backend(
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
master_shard_uds_path,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Run server
|
||||
server::run(
|
||||
backend,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
validation_workers,
|
||||
api_key,
|
||||
tokenizer_name,
|
||||
tokenizer_config_path,
|
||||
revision,
|
||||
hostname,
|
||||
port,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
disable_usage_stats,
|
||||
disable_crash_reports,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Argument validation error: {0}")]
|
||||
ArgumentValidation(String),
|
||||
#[error("Backend failed: {0}")]
|
||||
Backend(#[from] V3Error),
|
||||
#[error("WebServer error: {0}")]
|
||||
WebServer(#[from] server::WebServerError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
}
|
|
@ -1,17 +1,17 @@
|
|||
use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
|
||||
use crate::infer::InferError;
|
||||
use crate::infer::InferStreamResponse;
|
||||
use crate::validation::{
|
||||
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
|
||||
use crate::block_allocator::{BlockAllocation, BlockAllocator};
|
||||
use crate::client;
|
||||
use crate::client::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::{max, min};
|
||||
use std::collections::VecDeque;
|
||||
use text_generation_client::v3::{
|
||||
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
use text_generation_router::validation::{
|
||||
Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,
|
||||
ValidStoppingParameters,
|
||||
};
|
||||
use text_generation_client::ChunksToString;
|
||||
use text_generation_client::Input;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
@ -337,8 +337,22 @@ impl State {
|
|||
batch_requests.push(Request {
|
||||
id,
|
||||
prefill_logprobs: entry.request.decoder_input_details,
|
||||
input_chunks: Some(Input {
|
||||
chunks: entry.request.inputs.clone(),
|
||||
input_chunks: Some(client::Input {
|
||||
chunks: entry
|
||||
.request
|
||||
.inputs
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|c| client::InputChunk {
|
||||
chunk: Some(match c {
|
||||
Chunk::Text(text) => client::Chunk::Text(text),
|
||||
Chunk::Image(image) => client::Chunk::Image(client::Image {
|
||||
data: image.data,
|
||||
mimetype: image.mimetype,
|
||||
}),
|
||||
}),
|
||||
})
|
||||
.collect(),
|
||||
}),
|
||||
inputs: entry.request.inputs.chunks_to_string(),
|
||||
truncate: entry.request.truncate,
|
|
@ -21,7 +21,7 @@ float-ord = "0.3.2"
|
|||
serde = {version = "1.0.188", features = ["derive"]}
|
||||
serde_json = "1.0"
|
||||
tabled = "0.14.0"
|
||||
text-generation-client = { path = "../router/client" }
|
||||
text-generation-client = { path = "../backends/client" }
|
||||
thiserror = "1.0.48"
|
||||
tokenizers = { workspace = true }
|
||||
tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
|
||||
|
|
|
@ -1580,16 +1580,11 @@
|
|||
"type": "object",
|
||||
"required": [
|
||||
"model_id",
|
||||
"model_dtype",
|
||||
"model_device_type",
|
||||
"max_concurrent_requests",
|
||||
"max_best_of",
|
||||
"max_stop_sequences",
|
||||
"max_input_tokens",
|
||||
"max_total_tokens",
|
||||
"waiting_served_ratio",
|
||||
"max_batch_total_tokens",
|
||||
"max_waiting_tokens",
|
||||
"validation_workers",
|
||||
"max_client_batch_size",
|
||||
"router",
|
||||
|
@ -1601,18 +1596,6 @@
|
|||
"example": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"max_batch_size": {
|
||||
"type": "integer",
|
||||
"example": "null",
|
||||
"nullable": true,
|
||||
"minimum": 0
|
||||
},
|
||||
"max_batch_total_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"example": "32000",
|
||||
"minimum": 0
|
||||
},
|
||||
"max_best_of": {
|
||||
"type": "integer",
|
||||
"example": "2",
|
||||
|
@ -1644,19 +1627,6 @@
|
|||
"example": "2048",
|
||||
"minimum": 0
|
||||
},
|
||||
"max_waiting_tokens": {
|
||||
"type": "integer",
|
||||
"example": "20",
|
||||
"minimum": 0
|
||||
},
|
||||
"model_device_type": {
|
||||
"type": "string",
|
||||
"example": "cuda"
|
||||
},
|
||||
"model_dtype": {
|
||||
"type": "string",
|
||||
"example": "torch.float16"
|
||||
},
|
||||
"model_id": {
|
||||
"type": "string",
|
||||
"description": "Model info",
|
||||
|
@ -1690,11 +1660,6 @@
|
|||
"version": {
|
||||
"type": "string",
|
||||
"example": "0.5.0"
|
||||
},
|
||||
"waiting_served_ratio": {
|
||||
"type": "number",
|
||||
"format": "float",
|
||||
"example": "1.2"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
@ -7,25 +7,18 @@ edition.workspace = true
|
|||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "text-generation-router"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.74"
|
||||
async-stream = "0.3.5"
|
||||
axum = { version = "0.7", features = ["json"] }
|
||||
axum-tracing-opentelemetry = "0.16"
|
||||
text-generation-client = { path = "client" }
|
||||
clap = { version = "4.4.5", features = ["derive", "env"] }
|
||||
futures = "0.3.28"
|
||||
hf-hub = { workspace = true }
|
||||
itertools = "0.10"
|
||||
jsonschema = { version = "0.17.1", features = ["draft202012"] }
|
||||
metrics = "0.23.0"
|
||||
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
|
||||
metrics = { workspace = true }
|
||||
metrics-exporter-prometheus = { workspace = true }
|
||||
nohash-hasher = "0.2.0"
|
||||
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
|
||||
opentelemetry-otlp = "0.13.0"
|
||||
|
@ -55,6 +48,7 @@ base64 = { workspace = true }
|
|||
sysinfo = "0.30.13"
|
||||
uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] }
|
||||
csv = "1.3.0"
|
||||
ureq = "=2.9"
|
||||
|
||||
|
||||
[build-dependencies]
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
*
|
|
@ -1 +0,0 @@
|
|||
*
|
|
@ -1,528 +1,85 @@
|
|||
/// Batching and inference logic
|
||||
use crate::infer::v3::queue::{Entry, Queue};
|
||||
use crate::infer::{
|
||||
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
|
||||
use crate::infer::InferError;
|
||||
use crate::{
|
||||
ChatTemplateInputs, GrammarType, Message, MessageChunk, TextMessage, TokenizerConfigToken,
|
||||
};
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::{FinishReason, PrefillToken, Token};
|
||||
use nohash_hasher::IntMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use text_generation_client::v3::{Batch, CachedBatch, Generation, ShardedClient};
|
||||
use text_generation_client::ClientError;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
|
||||
pub(crate) struct SchedulerV3 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
/// Raise a exception (custom function) used in the chat templates
|
||||
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
||||
}
|
||||
|
||||
impl SchedulerV3 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ChatTemplate {
|
||||
template: Template<'static, 'static>,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
use_default_tool_template: bool,
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
requires_padding: bool,
|
||||
window_size: Option<u32>,
|
||||
speculate: u32,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
template: String,
|
||||
bos_token: Option<TokenizerConfigToken>,
|
||||
eos_token: Option<TokenizerConfigToken>,
|
||||
) -> Self {
|
||||
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
|
||||
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let block_size = if flashdecoding { 256 } else { 16 };
|
||||
let queue = Queue::new(
|
||||
requires_padding,
|
||||
block_size,
|
||||
window_size,
|
||||
speculate,
|
||||
max_batch_total_tokens,
|
||||
);
|
||||
let batching_task_notifier = Arc::new(Notify::new());
|
||||
let mut env = Box::new(Environment::new());
|
||||
// enable things like .strip() or .capitalize()
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
// Spawn batching background task that contains all the inference logic
|
||||
tokio::spawn(batching_task(
|
||||
client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
queue.clone(),
|
||||
batching_task_notifier.clone(),
|
||||
generation_health,
|
||||
));
|
||||
// check if contains the tools variable within the template
|
||||
let use_default_tool_template =
|
||||
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
let template = Box::leak(env)
|
||||
.template_from_str(Box::leak(template_str))
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
queue,
|
||||
batching_task_notifier,
|
||||
}
|
||||
template,
|
||||
bos_token: bos_token.map(|token| token.as_str().to_string()),
|
||||
eos_token: eos_token.map(|token| token.as_str().to_string()),
|
||||
use_default_tool_template,
|
||||
}
|
||||
}
|
||||
|
||||
impl Scheduler for SchedulerV3 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
pub(crate) fn apply(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
permit: OwnedSemaphorePermit,
|
||||
) -> Result<GenerateStreamResponse, InferError> {
|
||||
// MPSC channel to communicate with the background batching task
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
let input_length = request.input_length;
|
||||
|
||||
// Append the request to the queue
|
||||
self.queue.append(Entry {
|
||||
request,
|
||||
response_tx,
|
||||
span: Span::current(),
|
||||
temp_span: None,
|
||||
queue_time: Instant::now(),
|
||||
batch_time: None,
|
||||
block_allocation: None,
|
||||
});
|
||||
|
||||
// Notify the background task that we have a new entry in the queue that needs
|
||||
// to be batched
|
||||
self.batching_task_notifier.notify_one();
|
||||
|
||||
// Return stream
|
||||
Ok((
|
||||
permit,
|
||||
input_length,
|
||||
UnboundedReceiverStream::new(response_rx),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
/// Batches requests and sends them to the inference server
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn batching_task(
|
||||
mut client: ShardedClient,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: u32,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
queue: Queue,
|
||||
notifier: Arc<Notify>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) {
|
||||
// Infinite loop
|
||||
loop {
|
||||
// Wait for a notification from the Infer struct
|
||||
notifier.notified().await;
|
||||
|
||||
// Get the next batch from the queue
|
||||
// This batch might be smaller than the maximum batch size if there are not enough requests
|
||||
// waiting in the queue
|
||||
while let Some((mut entries, batch, span)) = queue
|
||||
.next_batch(
|
||||
None,
|
||||
max_batch_size,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
|
||||
.instrument(span)
|
||||
.await;
|
||||
let mut waiting_tokens = 1;
|
||||
|
||||
// We loop until we do not receive any cached batch from the inference server (== until
|
||||
// all requests have met their stopping criteria)
|
||||
while let Some(batch) = cached_batch {
|
||||
// Get current batch info
|
||||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
// to add a new batch even though its size might be small
|
||||
None
|
||||
} else {
|
||||
// Minimum batch size
|
||||
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
|
||||
};
|
||||
|
||||
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
|
||||
let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
|
||||
|
||||
// Try to get a new batch
|
||||
if let Some((mut new_entries, new_batch, span)) = queue
|
||||
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
|
||||
.await
|
||||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||
.increment(1);
|
||||
} else {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to add the info that this entry is waiting
|
||||
// because a new batch is being computed
|
||||
let entry_waiting_span = info_span!(parent: &entry.span, "waiting");
|
||||
// Add relationships
|
||||
span.follows_from(&entry_waiting_span);
|
||||
entry_waiting_span.follows_from(&span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_waiting_span);
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
let new_cached_batch =
|
||||
prefill(&mut client, new_batch, &mut new_entries, &generation_health)
|
||||
.instrument(span)
|
||||
.await;
|
||||
// Reset waiting counter
|
||||
waiting_tokens = 1;
|
||||
// Extend current batch with the new batch
|
||||
if let Some(new_cached_batch) = new_cached_batch {
|
||||
entries.extend(new_entries);
|
||||
batches.push(new_cached_batch);
|
||||
}
|
||||
}
|
||||
|
||||
// Create span for this batch to add context to inference calls
|
||||
let next_batch_size = entries.len();
|
||||
let next_batch_span =
|
||||
info_span!(parent: None, "batch", batch_size = next_batch_size);
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
// Create a new span to link the batch back to this entry
|
||||
let entry_batch_span = info_span!(parent: &entry.span, "infer");
|
||||
// Add relationships
|
||||
next_batch_span.follows_from(&entry_batch_span);
|
||||
entry_batch_span.follows_from(&next_batch_span);
|
||||
// Update entry
|
||||
entry.temp_span = Some(entry_batch_span);
|
||||
});
|
||||
|
||||
cached_batch = decode(&mut client, batches, &mut entries, &generation_health)
|
||||
.instrument(next_batch_span)
|
||||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn prefill(
|
||||
client: &mut ShardedClient,
|
||||
batch: Batch,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
generation_health: &Arc<AtomicBool>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
// Update health
|
||||
generation_health.store(true, Ordering::SeqCst);
|
||||
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
// Update health
|
||||
generation_health.store(false, Ordering::SeqCst);
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn decode(
|
||||
client: &mut ShardedClient,
|
||||
batches: Vec<CachedBatch>,
|
||||
entries: &mut IntMap<u64, Entry>,
|
||||
generation_health: &Arc<AtomicBool>,
|
||||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
// Update health
|
||||
generation_health.store(true, Ordering::SeqCst);
|
||||
|
||||
let start_filtering_time = Instant::now();
|
||||
// Send generated tokens and filter stopped entries
|
||||
filter_send_generations(generations, entries);
|
||||
|
||||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||
.record(concat_duration.as_secs_f64());
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
Err(err) => {
|
||||
generation_health.store(false, Ordering::SeqCst);
|
||||
for id in batch_ids {
|
||||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter a `batch` and remove all requests not present in `entries`
|
||||
#[instrument(skip_all)]
|
||||
async fn filter_batch(
|
||||
client: &mut ShardedClient,
|
||||
next_batch: Option<CachedBatch>,
|
||||
entries: &IntMap<u64, Entry>,
|
||||
) -> Option<CachedBatch> {
|
||||
let mut batch = next_batch?;
|
||||
|
||||
// No need to filter
|
||||
if batch.size as usize == entries.len() {
|
||||
return Some(batch);
|
||||
}
|
||||
|
||||
let id = batch.id;
|
||||
|
||||
// Retain only requests that are still in entries
|
||||
batch.request_ids.retain(|id| entries.contains_key(id));
|
||||
|
||||
if batch.request_ids.is_empty() {
|
||||
// All requests have been filtered out
|
||||
// Next batch is now empty
|
||||
// Clear it from the Python shards cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.clear_cache(Some(id)).await.unwrap();
|
||||
None
|
||||
} else {
|
||||
// Filter Python shard cache
|
||||
// We unwrap here as we need to panic since we cannot recover if this method fails
|
||||
client.filter_batch(id, batch.request_ids).await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Send one or multiple `InferStreamResponse` to Infer for all `entries`
|
||||
/// and filter entries
|
||||
#[instrument(skip_all)]
|
||||
fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
|
||||
generations.into_iter().for_each(|generation| {
|
||||
let id = generation.request_id;
|
||||
// Get entry
|
||||
// We can `expect` here as the request id should always be in the entries
|
||||
let entry = entries
|
||||
.get(&id)
|
||||
.expect("ID not found in entries. This is a bug.");
|
||||
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_generation", generation = ?generation).entered();
|
||||
// Send generation responses back to the infer task
|
||||
// If the receive an error from the Flume channel, it means that the client dropped the
|
||||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
err
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
mut messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content.push(MessageChunk::Text {
|
||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||
});
|
||||
}
|
||||
|
||||
/// Send responses through the `entry` response channel
|
||||
fn send_responses(
|
||||
generation: Generation,
|
||||
entry: &Entry,
|
||||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
|
||||
let mut stopped = false;
|
||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
|
||||
if let Some(prefill_tokens) = generation.prefill_tokens {
|
||||
// Create Token objects
|
||||
// We do that here instead of in the Python code as Rust for loops are faster
|
||||
let prefill_tokens = prefill_tokens
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(prefill_tokens.logprobs)
|
||||
.zip(prefill_tokens.texts)
|
||||
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
|
||||
.collect();
|
||||
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;
|
||||
}
|
||||
|
||||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
.zip(tokens_.logprobs)
|
||||
.zip(tokens_.texts)
|
||||
.zip(tokens_.is_special)
|
||||
.enumerate()
|
||||
.peekable();
|
||||
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
|
||||
let token = Token {
|
||||
id,
|
||||
text,
|
||||
logprob,
|
||||
special,
|
||||
};
|
||||
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||
top_tokens_
|
||||
.ids
|
||||
.iter()
|
||||
.zip(top_tokens_.logprobs.iter())
|
||||
.zip(top_tokens_.texts.iter())
|
||||
.zip(top_tokens_.is_special.iter())
|
||||
.map(|(((&id, &logprob), text), &special)| Token {
|
||||
id,
|
||||
text: text.to_string(),
|
||||
logprob,
|
||||
special,
|
||||
self.template
|
||||
.render(ChatTemplateInputs {
|
||||
messages,
|
||||
bos_token: self.bos_token.as_deref(),
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
tools: None,
|
||||
tools_prompt: None,
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
match (&generation.generated_text, iterator.peek()) {
|
||||
(Some(generated_text), None) => {
|
||||
// Generation has ended
|
||||
stopped = true;
|
||||
// Send message
|
||||
entry.response_tx.send(Ok(InferStreamResponse::End {
|
||||
token,
|
||||
top_tokens,
|
||||
generated_text: GeneratedText::from(generated_text.clone()),
|
||||
queued: entry.queue_time,
|
||||
start: entry.batch_time.unwrap(),
|
||||
}))?;
|
||||
}
|
||||
_ => {
|
||||
// Send message
|
||||
entry
|
||||
.response_tx
|
||||
.send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stopped)
|
||||
}
|
||||
|
||||
/// Send errors to Infer for all `entries`
|
||||
#[instrument(skip_all)]
|
||||
fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
||||
entries.drain().for_each(|(_, entry)| {
|
||||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
entry
|
||||
.response_tx
|
||||
.send(Err(err))
|
||||
.unwrap_or(());
|
||||
});
|
||||
}
|
||||
|
||||
impl From<text_generation_client::v3::GeneratedText> for GeneratedText {
|
||||
fn from(value: text_generation_client::v3::GeneratedText) -> Self {
|
||||
let v3_finish_reason =
|
||||
text_generation_client::v3::FinishReason::try_from(value.finish_reason).unwrap();
|
||||
let finish_reason = match v3_finish_reason {
|
||||
text_generation_client::v3::FinishReason::Length => FinishReason::Length,
|
||||
text_generation_client::v3::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
|
||||
text_generation_client::v3::FinishReason::StopSequence => FinishReason::StopSequence,
|
||||
};
|
||||
|
||||
Self {
|
||||
text: value.text,
|
||||
generated_tokens: value.generated_tokens,
|
||||
finish_reason,
|
||||
seed: value.seed,
|
||||
}
|
||||
.map_err(InferError::TemplateError)
|
||||
}
|
||||
}
|
||||
|
||||
// tests
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::infer::raise_exception;
|
||||
use crate::infer::chat_template::raise_exception;
|
||||
use crate::{ChatTemplateInputs, TextMessage};
|
||||
use minijinja::Environment;
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::Health;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct HealthCheck {
|
||||
client: Arc<dyn Health + Send + Sync>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl HealthCheck {
|
||||
pub(crate) fn new(
|
||||
client: Arc<dyn Health + Send + Sync>,
|
||||
generation_health: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
generation_health,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn check(&mut self) -> bool {
|
||||
let value = if self.generation_health.load(Ordering::SeqCst) {
|
||||
// Generation is healthy, we only check that the shards can allocate on device
|
||||
self.client.device_health().await
|
||||
} else {
|
||||
self.client.model_health().await
|
||||
}
|
||||
.is_ok();
|
||||
// Update generation health
|
||||
self.generation_health.store(value, Ordering::SeqCst);
|
||||
value
|
||||
}
|
||||
}
|
|
@ -1,23 +1,18 @@
|
|||
mod health;
|
||||
pub(crate) mod v2;
|
||||
pub(crate) mod v3;
|
||||
|
||||
pub(crate) use health::HealthCheck;
|
||||
// pub(crate) mod v2;
|
||||
mod chat_template;
|
||||
pub mod tool_grammar;
|
||||
|
||||
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
|
||||
use crate::GrammarType;
|
||||
use crate::{
|
||||
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
|
||||
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice,
|
||||
};
|
||||
use crate::{
|
||||
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
|
||||
ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,
|
||||
Message, PrefillToken, Token,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chat_template::ChatTemplate;
|
||||
use futures::future::try_join_all;
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
use minijinja::ErrorKind;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
|
||||
|
@ -26,12 +21,14 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
|
|||
use tokio_stream::StreamExt;
|
||||
use tracing::instrument;
|
||||
|
||||
pub(crate) trait Scheduler {
|
||||
#[async_trait]
|
||||
pub trait Backend {
|
||||
fn schedule(
|
||||
&self,
|
||||
request: ValidGenerateRequest,
|
||||
permit: OwnedSemaphorePermit,
|
||||
) -> Result<GenerateStreamResponse, InferError>;
|
||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
|
||||
|
||||
async fn health(&self, current_health: bool) -> bool;
|
||||
}
|
||||
|
||||
/// Inference struct
|
||||
|
@ -39,18 +36,20 @@ pub(crate) trait Scheduler {
|
|||
pub struct Infer {
|
||||
/// Validation
|
||||
validation: Validation,
|
||||
/// Request scheduler
|
||||
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
||||
/// Request backend
|
||||
backend: Arc<dyn Backend + Send + Sync>,
|
||||
/// Chat template
|
||||
chat_template: Option<ChatTemplate>,
|
||||
/// Inference limit
|
||||
limit_concurrent_requests: Arc<Semaphore>,
|
||||
/// Backend health
|
||||
backend_health: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl Infer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
scheduler: Arc<dyn Scheduler + Send + Sync>,
|
||||
backend: impl Backend + Send + Sync + 'static,
|
||||
validation: Validation,
|
||||
max_concurrent_requests: usize,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
|
@ -71,18 +70,22 @@ impl Infer {
|
|||
// Inference limit with a semaphore
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||
|
||||
// Backend health
|
||||
let backend_health = Arc::new(AtomicBool::new(false));
|
||||
|
||||
Self {
|
||||
validation,
|
||||
scheduler,
|
||||
backend: Arc::new(backend),
|
||||
chat_template,
|
||||
limit_concurrent_requests: semaphore,
|
||||
backend_health,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a new request to the queue and return a stream of InferStreamResponse
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) async fn generate_stream(
|
||||
&self,
|
||||
pub(crate) async fn generate_stream<'a>(
|
||||
&'a self,
|
||||
request: GenerateRequest,
|
||||
) -> Result<GenerateStreamResponse, InferError> {
|
||||
// Limit concurrent requests by acquiring a permit from the semaphore
|
||||
|
@ -103,7 +106,10 @@ impl Infer {
|
|||
err
|
||||
})?;
|
||||
|
||||
self.scheduler.schedule(valid_request, permit)
|
||||
let input_length = valid_request.input_length;
|
||||
let generation_stream = self.backend.schedule(valid_request)?;
|
||||
|
||||
Ok((permit, input_length, generation_stream))
|
||||
}
|
||||
|
||||
/// Tokenizer the input
|
||||
|
@ -155,7 +161,7 @@ impl Infer {
|
|||
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
|
||||
|
||||
// Create stream and keep semaphore permit as long as generate lives
|
||||
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
|
||||
let (_permit, _input_length, stream) = self.generate_stream(request).await?;
|
||||
|
||||
// Return values
|
||||
let mut result_prefill = Vec::new();
|
||||
|
@ -165,6 +171,8 @@ impl Infer {
|
|||
let mut result_start = None;
|
||||
let mut result_queued = None;
|
||||
|
||||
let mut stream = Box::pin(stream);
|
||||
|
||||
// Iterate on stream
|
||||
while let Some(response) = stream.next().await {
|
||||
match response? {
|
||||
|
@ -256,207 +264,15 @@ impl Infer {
|
|||
let best_response = infer_responses.remove(max_index);
|
||||
Ok((best_response, infer_responses))
|
||||
}
|
||||
}
|
||||
|
||||
/// Raise a exception (custom function) used in the chat templates
|
||||
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ChatTemplate {
|
||||
template: Template<'static, 'static>,
|
||||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
use_default_tool_template: bool,
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
fn new(
|
||||
template: String,
|
||||
bos_token: Option<TokenizerConfigToken>,
|
||||
eos_token: Option<TokenizerConfigToken>,
|
||||
) -> Self {
|
||||
let mut env = Box::new(Environment::new());
|
||||
// enable things like .strip() or .capitalize()
|
||||
env.set_unknown_method_callback(pycompat::unknown_method_callback);
|
||||
let template_str = template.into_boxed_str();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
// check if contains the tools variable within the template
|
||||
let use_default_tool_template =
|
||||
!template_str.as_ref().replace(' ', "").contains("{{tools}}");
|
||||
// leaking env and template_str as read-only, static resources for performance.
|
||||
let template = Box::leak(env)
|
||||
.template_from_str(Box::leak(template_str))
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
template,
|
||||
bos_token: bos_token.map(|token| token.as_str().to_string()),
|
||||
eos_token: eos_token.map(|token| token.as_str().to_string()),
|
||||
use_default_tool_template,
|
||||
}
|
||||
}
|
||||
|
||||
fn apply(
|
||||
&self,
|
||||
mut messages: Vec<Message>,
|
||||
grammar_with_prompt: Option<(GrammarType, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
if self.use_default_tool_template {
|
||||
if let Some(last_message) = messages.last_mut() {
|
||||
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
|
||||
last_message.content.push(MessageChunk::Text {
|
||||
text: format!("\n---\n{}\n{}", tool_prompt, tools),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
|
||||
|
||||
self.template
|
||||
.render(ChatTemplateInputs {
|
||||
messages,
|
||||
bos_token: self.bos_token.as_deref(),
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
tools: None,
|
||||
tools_prompt: None,
|
||||
})
|
||||
.map_err(InferError::TemplateError)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolGrammar {}
|
||||
|
||||
impl ToolGrammar {
|
||||
// find a tool by name
|
||||
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
|
||||
tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == name)
|
||||
.cloned()
|
||||
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
|
||||
}
|
||||
|
||||
pub fn apply(
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: ToolChoice,
|
||||
) -> Result<Option<Tools>, InferError> {
|
||||
// if no tools are provided, we return None
|
||||
let tools = match tools {
|
||||
Some(tools) if !tools.is_empty() => tools,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||
|
||||
// if tools are provided and no tool_choice we default to the OneOf
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![Self::find_tool_by_name(&tools, &name)?]
|
||||
}
|
||||
ToolType::Function { function } => {
|
||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||
}
|
||||
ToolType::OneOf => tools,
|
||||
ToolType::NoTool => return Ok(None),
|
||||
};
|
||||
|
||||
// adds the error notification function for LLM feedback if required
|
||||
let mut text_response_properties = Map::new();
|
||||
text_response_properties.insert(
|
||||
"error".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
}),
|
||||
);
|
||||
text_response_properties.insert(
|
||||
"_name".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"const": "notify_error"
|
||||
}),
|
||||
);
|
||||
|
||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
|
||||
// Clone the existing parameters, which are expected to be a JSON object
|
||||
let mut params = if let Value::Object(params) = &func.arguments {
|
||||
params.clone()
|
||||
} else {
|
||||
Map::new()
|
||||
};
|
||||
|
||||
// Insert the function's description at the top level, outside of properties
|
||||
params.insert(
|
||||
"description".to_string(),
|
||||
Value::String(func.description.clone().unwrap_or_default()),
|
||||
);
|
||||
|
||||
// Ensure 'properties' exists and is an object
|
||||
let properties = params
|
||||
.entry("properties".to_string())
|
||||
.or_insert_with(|| json!({}))
|
||||
.as_object_mut()
|
||||
.unwrap();
|
||||
|
||||
// Insert the constant for the function name inside 'properties'
|
||||
properties.insert(
|
||||
"_name".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"const": func.name.clone(),
|
||||
// "description": "The name of the function"
|
||||
}),
|
||||
);
|
||||
|
||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
||||
let required = params
|
||||
.entry("required".to_string())
|
||||
.or_insert_with(|| json!([]))
|
||||
.as_array_mut()
|
||||
.unwrap();
|
||||
|
||||
// Add 'name' to the 'required' array if it is not already present
|
||||
if !required.iter().any(|r| r == "_name") {
|
||||
required.push(json!("_name"));
|
||||
}
|
||||
|
||||
(func.name, Value::Object(params))
|
||||
})
|
||||
.chain([(
|
||||
"notify_error".to_string(),
|
||||
serde_json::json!({
|
||||
"properties": text_response_properties,
|
||||
"required": ["error", "_name"],
|
||||
"type": "object"
|
||||
}),
|
||||
)])
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.chain(std::iter::once(FunctionRef {
|
||||
ref_path: "#/$functions/notify_error".to_string(),
|
||||
}))
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Some(tools))
|
||||
#[instrument(skip(self))]
|
||||
pub(crate) async fn health(&self) -> bool {
|
||||
let health = self
|
||||
.backend
|
||||
.health(self.backend_health.load(Ordering::SeqCst))
|
||||
.await;
|
||||
self.backend_health.store(health, Ordering::SeqCst);
|
||||
health
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -468,15 +284,15 @@ pub(crate) type GenerateStreamResponse = (
|
|||
);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct GeneratedText {
|
||||
pub(crate) text: String,
|
||||
pub(crate) generated_tokens: u32,
|
||||
pub(crate) finish_reason: FinishReason,
|
||||
pub(crate) seed: Option<u64>,
|
||||
pub struct GeneratedText {
|
||||
pub text: String,
|
||||
pub generated_tokens: u32,
|
||||
pub finish_reason: FinishReason,
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum InferStreamResponse {
|
||||
pub enum InferStreamResponse {
|
||||
// Optional first message
|
||||
Prefill(Vec<PrefillToken>),
|
||||
// Intermediate messages
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
use crate::infer::InferError;
|
||||
use crate::{FunctionRef, FunctionsMap, Properties, Tool, ToolChoice, ToolType, Tools};
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub(crate) struct ToolGrammar {}
|
||||
|
||||
impl ToolGrammar {
|
||||
// find a tool by name
|
||||
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
|
||||
tools
|
||||
.iter()
|
||||
.find(|tool| tool.function.name == name)
|
||||
.cloned()
|
||||
.ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name)))
|
||||
}
|
||||
|
||||
pub fn apply(
|
||||
tools: Option<Vec<Tool>>,
|
||||
tool_choice: ToolChoice,
|
||||
) -> Result<Option<Tools>, InferError> {
|
||||
// if no tools are provided, we return None
|
||||
let tools = match tools {
|
||||
Some(tools) if !tools.is_empty() => tools,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);
|
||||
|
||||
// if tools are provided and no tool_choice we default to the OneOf
|
||||
let tools_to_use = match tool_choice {
|
||||
ToolType::FunctionName(name) => {
|
||||
vec![Self::find_tool_by_name(&tools, &name)?]
|
||||
}
|
||||
ToolType::Function { function } => {
|
||||
vec![Self::find_tool_by_name(&tools, &function.name)?]
|
||||
}
|
||||
ToolType::OneOf => tools,
|
||||
ToolType::NoTool => return Ok(None),
|
||||
};
|
||||
|
||||
// adds the error notification function for LLM feedback if required
|
||||
let mut text_response_properties = Map::new();
|
||||
text_response_properties.insert(
|
||||
"error".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": "The error or issue to notify"
|
||||
}),
|
||||
);
|
||||
text_response_properties.insert(
|
||||
"_name".to_string(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"const": "notify_error"
|
||||
}),
|
||||
);
|
||||
|
||||
let functions: HashMap<String, serde_json::Value> = tools_to_use
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
let func = tool.function.clone();
|
||||
|
||||
// Clone the existing parameters, which are expected to be a JSON object
|
||||
let mut params = if let Value::Object(params) = &func.arguments {
|
||||
params.clone()
|
||||
} else {
|
||||
Map::new()
|
||||
};
|
||||
|
||||
// Insert the function's description at the top level, outside of properties
|
||||
params.insert(
|
||||
"description".to_string(),
|
||||
Value::String(func.description.clone().unwrap_or_default()),
|
||||
);
|
||||
|
||||
// Ensure 'properties' exists and is an object
|
||||
let properties = params
|
||||
.entry("properties".to_string())
|
||||
.or_insert_with(|| json!({}))
|
||||
.as_object_mut()
|
||||
.unwrap();
|
||||
|
||||
// Insert the constant for the function name inside 'properties'
|
||||
properties.insert(
|
||||
"_name".to_string(),
|
||||
json!({
|
||||
"type": "string",
|
||||
"const": func.name.clone(),
|
||||
// "description": "The name of the function"
|
||||
}),
|
||||
);
|
||||
|
||||
// Check if 'required' exists, and it is an array. If not, create an empty array.
|
||||
let required = params
|
||||
.entry("required".to_string())
|
||||
.or_insert_with(|| json!([]))
|
||||
.as_array_mut()
|
||||
.unwrap();
|
||||
|
||||
// Add 'name' to the 'required' array if it is not already present
|
||||
if !required.iter().any(|r| r == "_name") {
|
||||
required.push(json!("_name"));
|
||||
}
|
||||
|
||||
(func.name, Value::Object(params))
|
||||
})
|
||||
.chain([(
|
||||
"notify_error".to_string(),
|
||||
serde_json::json!({
|
||||
"properties": text_response_properties,
|
||||
"required": ["error", "_name"],
|
||||
"type": "object"
|
||||
}),
|
||||
)])
|
||||
.collect();
|
||||
|
||||
let tools = Tools {
|
||||
functions_map: FunctionsMap { functions },
|
||||
properties: Properties {
|
||||
function: tools_to_use
|
||||
.iter()
|
||||
.map(|tool| FunctionRef {
|
||||
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
|
||||
})
|
||||
.chain(std::iter::once(FunctionRef {
|
||||
ref_path: "#/$functions/notify_error".to_string(),
|
||||
}))
|
||||
.collect(),
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Some(tools))
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
mod queue;
|
||||
mod scheduler;
|
||||
|
||||
pub(crate) use scheduler::SchedulerV2;
|
||||
pub(crate) use scheduler::BackendV2;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/// Batching and inference logic
|
||||
use crate::infer::v2::queue::{Entry, Queue};
|
||||
use crate::infer::{
|
||||
GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse, Scheduler,
|
||||
Backend, GenerateStreamResponse, GeneratedText, InferError, InferStreamResponse,
|
||||
};
|
||||
use crate::validation::ValidGenerateRequest;
|
||||
use crate::{FinishReason, PrefillToken, Token};
|
||||
|
@ -18,14 +18,14 @@ use tokio::time::Instant;
|
|||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub(crate) struct SchedulerV2 {
|
||||
pub(crate) struct BackendV2 {
|
||||
/// Request queue
|
||||
queue: Queue,
|
||||
/// Notify batcher on queue appends
|
||||
batching_task_notifier: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl SchedulerV2 {
|
||||
impl BackendV2 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn new(
|
||||
client: ShardedClient,
|
||||
|
@ -69,7 +69,7 @@ impl SchedulerV2 {
|
|||
}
|
||||
}
|
||||
|
||||
impl Scheduler for SchedulerV2 {
|
||||
impl Backend for BackendV2 {
|
||||
#[instrument(skip_all)]
|
||||
fn schedule(
|
||||
&self,
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
mod block_allocator;
|
||||
mod queue;
|
||||
mod scheduler;
|
||||
|
||||
pub(crate) use scheduler::SchedulerV3;
|
|
@ -1,11 +1,12 @@
|
|||
/// Text Generation Inference Webserver
|
||||
pub mod config;
|
||||
mod infer;
|
||||
pub mod infer;
|
||||
pub mod server;
|
||||
mod validation;
|
||||
pub mod validation;
|
||||
|
||||
#[cfg(feature = "kserve")]
|
||||
mod kserve;
|
||||
pub mod logging;
|
||||
|
||||
pub mod usage_stats;
|
||||
|
||||
|
@ -148,12 +149,13 @@ pub struct Info {
|
|||
pub model_id: String,
|
||||
#[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
|
||||
pub model_sha: Option<String>,
|
||||
#[schema(example = "torch.float16")]
|
||||
pub model_dtype: String,
|
||||
#[schema(example = "cuda")]
|
||||
pub model_device_type: String,
|
||||
// #[schema(example = "torch.float16")]
|
||||
// pub model_dtype: String,
|
||||
// #[schema(example = "cuda")]
|
||||
// pub model_device_type: String,
|
||||
#[schema(nullable = true, example = "text-generation")]
|
||||
pub model_pipeline_tag: Option<String>,
|
||||
|
||||
/// Router Parameters
|
||||
#[schema(example = "128")]
|
||||
pub max_concurrent_requests: usize,
|
||||
|
@ -165,18 +167,11 @@ pub struct Info {
|
|||
pub max_input_tokens: usize,
|
||||
#[schema(example = "2048")]
|
||||
pub max_total_tokens: usize,
|
||||
#[schema(example = "1.2")]
|
||||
pub waiting_served_ratio: f32,
|
||||
#[schema(example = "32000")]
|
||||
pub max_batch_total_tokens: u32,
|
||||
#[schema(example = "20")]
|
||||
pub max_waiting_tokens: usize,
|
||||
#[schema(nullable = true, example = "null")]
|
||||
pub max_batch_size: Option<usize>,
|
||||
#[schema(example = "2")]
|
||||
pub validation_workers: usize,
|
||||
#[schema(example = "32")]
|
||||
pub max_client_batch_size: usize,
|
||||
|
||||
/// Router Info
|
||||
#[schema(example = "text-generation-router")]
|
||||
pub router: &'static str,
|
||||
|
@ -1068,23 +1063,23 @@ impl From<CompatGenerateRequest> for GenerateRequest {
|
|||
#[derive(Debug, Serialize, ToSchema)]
|
||||
pub struct PrefillToken {
|
||||
#[schema(example = 0)]
|
||||
id: u32,
|
||||
pub id: u32,
|
||||
#[schema(example = "test")]
|
||||
text: String,
|
||||
pub text: String,
|
||||
#[schema(nullable = true, example = - 0.34)]
|
||||
logprob: f32,
|
||||
pub logprob: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema, Clone)]
|
||||
pub struct Token {
|
||||
#[schema(example = 0)]
|
||||
id: u32,
|
||||
pub id: u32,
|
||||
#[schema(example = "test")]
|
||||
text: String,
|
||||
pub text: String,
|
||||
#[schema(nullable = true, example = - 0.34)]
|
||||
logprob: f32,
|
||||
pub logprob: f32,
|
||||
#[schema(example = "false")]
|
||||
special: bool,
|
||||
pub special: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, ToSchema)]
|
||||
|
@ -1102,7 +1097,7 @@ pub struct SimpleToken {
|
|||
#[derive(Debug, Serialize, ToSchema)]
|
||||
#[serde(rename_all(serialize = "snake_case"))]
|
||||
#[schema(example = "Length")]
|
||||
pub(crate) enum FinishReason {
|
||||
pub enum FinishReason {
|
||||
#[schema(rename = "length")]
|
||||
Length,
|
||||
#[serde(rename = "eos_token")]
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||
use opentelemetry::sdk::trace;
|
||||
use opentelemetry::sdk::trace::Sampler;
|
||||
use opentelemetry::sdk::Resource;
|
||||
use opentelemetry::{global, KeyValue};
|
||||
use opentelemetry_otlp::WithExportConfig;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
|
||||
|
||||
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
|
||||
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
|
||||
/// - otlp_service_name service name to appear in APM
|
||||
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
|
||||
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
|
||||
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
|
||||
pub fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
|
||||
let mut layers = Vec::new();
|
||||
|
||||
// STDOUT/STDERR layer
|
||||
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
|
||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||
.with_file(true)
|
||||
.with_ansi(ansi)
|
||||
.with_line_number(true);
|
||||
|
||||
let fmt_layer = match json_output {
|
||||
true => fmt_layer.json().flatten_event(true).boxed(),
|
||||
false => fmt_layer.boxed(),
|
||||
};
|
||||
layers.push(fmt_layer);
|
||||
|
||||
// OpenTelemetry tracing layer
|
||||
if let Some(otlp_endpoint) = otlp_endpoint {
|
||||
global::set_text_map_propagator(TraceContextPropagator::new());
|
||||
|
||||
let tracer = opentelemetry_otlp::new_pipeline()
|
||||
.tracing()
|
||||
.with_exporter(
|
||||
opentelemetry_otlp::new_exporter()
|
||||
.tonic()
|
||||
.with_endpoint(otlp_endpoint),
|
||||
)
|
||||
.with_trace_config(
|
||||
trace::config()
|
||||
.with_resource(Resource::new(vec![KeyValue::new(
|
||||
"service.name",
|
||||
otlp_service_name,
|
||||
)]))
|
||||
.with_sampler(Sampler::AlwaysOn),
|
||||
)
|
||||
.install_batch(opentelemetry::runtime::Tokio);
|
||||
|
||||
if let Ok(tracer) = tracer {
|
||||
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
|
||||
init_tracing_opentelemetry::init_propagator().unwrap();
|
||||
};
|
||||
}
|
||||
|
||||
// Filter events with LOG_LEVEL
|
||||
let varname = "LOG_LEVEL";
|
||||
let env_filter = if let Ok(log_level) = std::env::var(varname) {
|
||||
// Override to avoid simple logs to be spammed with tokio level informations
|
||||
let log_level = match &log_level[..] {
|
||||
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
|
||||
"info" => "text_generation_launcher=info,text_generation_router=info",
|
||||
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
|
||||
log_level => log_level,
|
||||
};
|
||||
EnvFilter::builder()
|
||||
.with_default_directive(LevelFilter::INFO.into())
|
||||
.parse_lossy(log_level)
|
||||
} else {
|
||||
EnvFilter::new("info")
|
||||
};
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(layers)
|
||||
.init();
|
||||
}
|
|
@ -1,14 +1,13 @@
|
|||
/// HTTP Server logic
|
||||
use crate::config::Config;
|
||||
use crate::infer::v2::SchedulerV2;
|
||||
use crate::infer::v3::SchedulerV3;
|
||||
use crate::infer::{HealthCheck, Scheduler};
|
||||
use crate::infer::{Infer, InferError, InferResponse, InferStreamResponse, ToolGrammar};
|
||||
use crate::infer::tool_grammar::ToolGrammar;
|
||||
use crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};
|
||||
#[cfg(feature = "kserve")]
|
||||
use crate::kserve::{
|
||||
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
|
||||
kserve_model_metadata, kserve_model_metadata_ready,
|
||||
};
|
||||
use crate::usage_stats;
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters,
|
||||
|
@ -27,7 +26,7 @@ use crate::{
|
|||
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
|
||||
use async_stream::__private::AsyncStream;
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::{get, post};
|
||||
|
@ -37,15 +36,18 @@ use futures::stream::StreamExt;
|
|||
use futures::stream::{FuturesOrdered, FuturesUnordered};
|
||||
use futures::Stream;
|
||||
use futures::TryStreamExt;
|
||||
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
|
||||
use hf_hub::{Cache, Repo, RepoType};
|
||||
use http::header::AUTHORIZATION;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::Arc;
|
||||
use text_generation_client::{v2, v3, ClientError, ShardInfo};
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::{Path, PathBuf};
|
||||
use thiserror::Error;
|
||||
use tokenizers::processors::template::TemplateProcessing;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokio::select;
|
||||
use tokio::signal;
|
||||
|
@ -124,12 +126,10 @@ responses(
|
|||
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
|
||||
)
|
||||
)]
|
||||
#[instrument(skip(health))]
|
||||
#[instrument(skip(infer))]
|
||||
/// Health check method
|
||||
async fn health(
|
||||
mut health: Extension<HealthCheck>,
|
||||
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
match health.check().await {
|
||||
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
|
||||
match infer.health().await {
|
||||
true => Ok(()),
|
||||
false => Err((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
|
@ -430,8 +430,9 @@ async fn generate_stream_internal(
|
|||
} else {
|
||||
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
|
||||
// Keep permit as long as generate_stream lives
|
||||
Ok((_permit, _input_length, mut response_stream)) => {
|
||||
Ok((_permit, _input_length, response_stream)) => {
|
||||
let mut index = 0;
|
||||
let mut response_stream = Box::pin(response_stream);
|
||||
// Server-Sent Event stream
|
||||
while let Some(response) = response_stream.next().await {
|
||||
index += 1;
|
||||
|
@ -1396,40 +1397,6 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
|
|||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ComputeType(String);
|
||||
|
||||
/// Serving method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
master_shard_uds_path: String,
|
||||
model_info: HubModelInfo,
|
||||
compat_return_full_text: bool,
|
||||
max_concurrent_requests: usize,
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_top_n_tokens: u32,
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
tokenizer: Option<Tokenizer>,
|
||||
config: Option<Config>,
|
||||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
api_key: Option<String>,
|
||||
ngrok: bool,
|
||||
_ngrok_authtoken: Option<String>,
|
||||
_ngrok_edge: Option<String>,
|
||||
tokenizer_config: HubTokenizerConfig,
|
||||
preprocessor_config: Option<HubPreprocessorConfig>,
|
||||
processor_config: HubProcessorConfig,
|
||||
messages_api_enabled: bool,
|
||||
grammar_support: bool,
|
||||
max_client_batch_size: usize,
|
||||
print_schema_command: bool,
|
||||
) -> Result<(), WebServerError> {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
|
@ -1508,150 +1475,378 @@ pub async fn run(
|
|||
)
|
||||
)
|
||||
)]
|
||||
struct ApiDoc;
|
||||
pub struct ApiDoc;
|
||||
|
||||
pub fn schema() -> ApiDoc {
|
||||
ApiDoc
|
||||
}
|
||||
|
||||
/// Serving method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run(
|
||||
backend: impl Backend + Send + Sync + 'static,
|
||||
max_concurrent_requests: usize,
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_top_n_tokens: u32,
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
validation_workers: usize,
|
||||
api_key: Option<String>,
|
||||
tokenizer_name: String,
|
||||
tokenizer_config_path: Option<String>,
|
||||
revision: Option<String>,
|
||||
hostname: String,
|
||||
port: u16,
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
ngrok: bool,
|
||||
_ngrok_authtoken: Option<String>,
|
||||
_ngrok_edge: Option<String>,
|
||||
messages_api_enabled: bool,
|
||||
disable_grammar_support: bool,
|
||||
max_client_batch_size: usize,
|
||||
disable_usage_stats: bool,
|
||||
disable_crash_reports: bool,
|
||||
) -> Result<(), WebServerError> {
|
||||
// CORS allowed origins
|
||||
// map to go inside the option and then map to parse from String to HeaderValue
|
||||
// Finally, convert to AllowOrigin
|
||||
let allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
|
||||
AllowOrigin::list(
|
||||
cors_allow_origin
|
||||
.iter()
|
||||
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
|
||||
)
|
||||
});
|
||||
|
||||
// Parse Huggingface hub token
|
||||
let authorization_token = std::env::var("HF_TOKEN")
|
||||
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
|
||||
.ok();
|
||||
|
||||
// Tokenizer instance
|
||||
// This will only be used to validate payloads
|
||||
let local_path = Path::new(&tokenizer_name);
|
||||
|
||||
// Shared API builder initialization
|
||||
let api_builder = || {
|
||||
let mut builder = ApiBuilder::new()
|
||||
.with_progress(false)
|
||||
.with_token(authorization_token);
|
||||
|
||||
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
|
||||
builder = builder.with_cache_dir(cache_dir.into());
|
||||
}
|
||||
|
||||
builder
|
||||
};
|
||||
|
||||
// Decide if we need to use the API based on the revision and local path
|
||||
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
|
||||
|
||||
// Initialize API if needed
|
||||
#[derive(Clone)]
|
||||
enum Type {
|
||||
Api(Api),
|
||||
Cache(Cache),
|
||||
None,
|
||||
}
|
||||
let api = if use_api {
|
||||
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
||||
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
|
||||
.map_err(|_| ())
|
||||
.map(|cache_dir| Cache::new(cache_dir.into()))
|
||||
.unwrap_or_else(|_| Cache::default());
|
||||
tracing::warn!("Offline mode active using cache defaults");
|
||||
Type::Cache(cache)
|
||||
} else {
|
||||
tracing::info!("Using the Hugging Face API");
|
||||
match api_builder().build() {
|
||||
Ok(api) => Type::Api(api),
|
||||
Err(_) => {
|
||||
tracing::warn!("Unable to build the Hugging Face API");
|
||||
Type::None
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Type::None
|
||||
};
|
||||
|
||||
// Load tokenizer and model info
|
||||
let (
|
||||
tokenizer_filename,
|
||||
config_filename,
|
||||
tokenizer_config_filename,
|
||||
preprocessor_config_filename,
|
||||
processor_config_filename,
|
||||
model_info,
|
||||
) = match api {
|
||||
Type::None => (
|
||||
Some(local_path.join("tokenizer.json")),
|
||||
Some(local_path.join("config.json")),
|
||||
Some(local_path.join("tokenizer_config.json")),
|
||||
Some(local_path.join("preprocessor_config.json")),
|
||||
Some(local_path.join("processor_config.json")),
|
||||
None,
|
||||
),
|
||||
Type::Api(api) => {
|
||||
let api_repo = api.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||
));
|
||||
|
||||
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
|
||||
Ok(tokenizer_filename) => Some(tokenizer_filename),
|
||||
Err(_) => get_base_tokenizer(&api, &api_repo).await,
|
||||
};
|
||||
let config_filename = api_repo.get("config.json").await.ok();
|
||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
|
||||
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
|
||||
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
|
||||
|
||||
let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {
|
||||
Some(model_info)
|
||||
} else {
|
||||
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
|
||||
None
|
||||
};
|
||||
(
|
||||
tokenizer_filename,
|
||||
config_filename,
|
||||
tokenizer_config_filename,
|
||||
preprocessor_config_filename,
|
||||
processor_config_filename,
|
||||
model_info,
|
||||
)
|
||||
}
|
||||
Type::Cache(cache) => {
|
||||
let repo = cache.repo(Repo::with_revision(
|
||||
tokenizer_name.to_string(),
|
||||
RepoType::Model,
|
||||
revision.clone().unwrap_or_else(|| "main".to_string()),
|
||||
));
|
||||
(
|
||||
repo.get("tokenizer.json"),
|
||||
repo.get("config.json"),
|
||||
repo.get("tokenizer_config.json"),
|
||||
repo.get("preprocessor_config.json"),
|
||||
repo.get("processor_config.json"),
|
||||
None,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
|
||||
{
|
||||
HubTokenizerConfig::from_file(filename)
|
||||
} else {
|
||||
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
|
||||
};
|
||||
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
|
||||
tracing::warn!("Could not find tokenizer config locally and no API specified");
|
||||
HubTokenizerConfig::default()
|
||||
});
|
||||
|
||||
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
|
||||
let mut tokenizer = Tokenizer::from_file(filename).ok();
|
||||
if let Some(tokenizer) = &mut tokenizer {
|
||||
if let Some(class) = &tokenizer_config.tokenizer_class {
|
||||
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
|
||||
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
|
||||
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
|
||||
tokenizer.with_post_processor(post_processor);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tokenizer
|
||||
});
|
||||
|
||||
let config: Option<Config> = config_filename.and_then(|filename| {
|
||||
std::fs::read_to_string(filename)
|
||||
.ok()
|
||||
.as_ref()
|
||||
.and_then(|c| {
|
||||
let config: Result<Config, _> = serde_json::from_str(c);
|
||||
if let Err(err) = &config {
|
||||
tracing::warn!("Could not parse config {err:?}");
|
||||
}
|
||||
config.ok()
|
||||
})
|
||||
});
|
||||
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
|
||||
model_id: tokenizer_name.to_string(),
|
||||
sha: None,
|
||||
pipeline_tag: None,
|
||||
});
|
||||
|
||||
let processor_config = processor_config_filename
|
||||
.and_then(HubProcessorConfig::from_file)
|
||||
.unwrap_or_default();
|
||||
|
||||
let preprocessor_config: Option<HubPreprocessorConfig> =
|
||||
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
|
||||
|
||||
tracing::info!("Using config {config:?}");
|
||||
if tokenizer.is_none() {
|
||||
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
|
||||
tracing::warn!("Rust input length validation and truncation is disabled");
|
||||
}
|
||||
|
||||
// Only send usage stats when TGI is run in container and the function returns Some
|
||||
let is_container = matches!(usage_stats::is_container(), Ok(true));
|
||||
|
||||
let user_agent = if !disable_usage_stats && is_container {
|
||||
let reduced_args = usage_stats::Args::new(
|
||||
config.clone(),
|
||||
tokenizer_config.tokenizer_class.clone(),
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
// waiting_served_ratio,
|
||||
// max_batch_prefill_tokens,
|
||||
// max_batch_total_tokens,
|
||||
// max_waiting_tokens,
|
||||
// max_batch_size,
|
||||
revision.clone(),
|
||||
validation_workers,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
disable_usage_stats,
|
||||
disable_crash_reports,
|
||||
);
|
||||
Some(usage_stats::UserAgent::new(reduced_args))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(ref ua) = user_agent {
|
||||
let start_event =
|
||||
usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
|
||||
tokio::spawn(async move {
|
||||
start_event.send().await;
|
||||
});
|
||||
};
|
||||
let compat_return_full_text = match &model_info.pipeline_tag {
|
||||
None => {
|
||||
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
|
||||
true
|
||||
}
|
||||
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
|
||||
};
|
||||
let result = start(
|
||||
backend,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
validation_workers,
|
||||
api_key,
|
||||
config,
|
||||
(tokenizer, tokenizer_config),
|
||||
(preprocessor_config, processor_config),
|
||||
hostname,
|
||||
port,
|
||||
ngrok,
|
||||
_ngrok_authtoken,
|
||||
_ngrok_edge,
|
||||
messages_api_enabled,
|
||||
disable_grammar_support,
|
||||
max_client_batch_size,
|
||||
model_info,
|
||||
compat_return_full_text,
|
||||
allow_origin,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Some(ua) = user_agent {
|
||||
match result {
|
||||
Ok(_) => {
|
||||
let stop_event = usage_stats::UsageStatsEvent::new(
|
||||
ua.clone(),
|
||||
usage_stats::EventType::Stop,
|
||||
None,
|
||||
);
|
||||
stop_event.send().await;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
if !disable_crash_reports {
|
||||
let error_event = usage_stats::UsageStatsEvent::new(
|
||||
ua.clone(),
|
||||
usage_stats::EventType::Error,
|
||||
Some(e.to_string()),
|
||||
);
|
||||
error_event.send().await;
|
||||
} else {
|
||||
let unknow_error_event = usage_stats::UsageStatsEvent::new(
|
||||
ua.clone(),
|
||||
usage_stats::EventType::Error,
|
||||
Some("unknow_error".to_string()),
|
||||
);
|
||||
unknow_error_event.send().await;
|
||||
}
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn start(
|
||||
backend: impl Backend + Send + Sync + 'static,
|
||||
max_concurrent_requests: usize,
|
||||
max_best_of: usize,
|
||||
max_stop_sequences: usize,
|
||||
max_top_n_tokens: u32,
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
validation_workers: usize,
|
||||
api_key: Option<String>,
|
||||
config: Option<Config>,
|
||||
(tokenizer, tokenizer_config): (Option<Tokenizer>, HubTokenizerConfig),
|
||||
(preprocessor_config, processor_config): (Option<HubPreprocessorConfig>, HubProcessorConfig),
|
||||
hostname: String,
|
||||
port: u16,
|
||||
ngrok: bool,
|
||||
_ngrok_authtoken: Option<String>,
|
||||
_ngrok_edge: Option<String>,
|
||||
messages_api_enabled: bool,
|
||||
disable_grammar_support: bool,
|
||||
max_client_batch_size: usize,
|
||||
model_info: HubModelInfo,
|
||||
compat_return_full_text: bool,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
) -> Result<(), WebServerError> {
|
||||
// Determine the server port based on the feature and environment variable.
|
||||
let port = if cfg!(feature = "google") {
|
||||
std::env::var("AIP_HTTP_PORT")
|
||||
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
|
||||
.unwrap_or(port)
|
||||
} else {
|
||||
port
|
||||
};
|
||||
|
||||
let addr = match hostname.parse() {
|
||||
Ok(ip) => SocketAddr::new(ip, port),
|
||||
Err(_) => {
|
||||
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
|
||||
}
|
||||
};
|
||||
|
||||
// Create state
|
||||
if print_schema_command {
|
||||
let api_doc = ApiDoc::openapi();
|
||||
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
|
||||
println!("{}", api_doc);
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
||||
// Open connection, get model info and warmup
|
||||
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (
|
||||
Arc<dyn Scheduler + Send + Sync>,
|
||||
HealthCheck,
|
||||
ShardInfo,
|
||||
u32,
|
||||
) = {
|
||||
// Helper function to check both v2 and v3
|
||||
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
|
||||
match max_supported_batch_total_tokens {
|
||||
// Older models do not support automatic max-batch-total-tokens
|
||||
None => {
|
||||
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
|
||||
16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
|
||||
);
|
||||
tracing::warn!("Model does not support automatic max batch total tokens");
|
||||
Ok(max_batch_total_tokens)
|
||||
}
|
||||
// Flash attention models return their max supported total tokens
|
||||
Some(max_supported_batch_total_tokens) => {
|
||||
// Warn if user added his own max-batch-total-tokens as we will ignore it
|
||||
if max_batch_total_tokens.is_some() {
|
||||
tracing::warn!(
|
||||
"`--max-batch-total-tokens` is deprecated for Flash \
|
||||
Attention models."
|
||||
);
|
||||
tracing::warn!(
|
||||
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
|
||||
);
|
||||
}
|
||||
if max_total_tokens as u32 > max_supported_batch_total_tokens {
|
||||
return Err(WebServerError::NotEnoughMemory(max_total_tokens));
|
||||
}
|
||||
|
||||
Ok(max_supported_batch_total_tokens)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let generation_health = Arc::new(AtomicBool::new(false));
|
||||
|
||||
match v3::ShardedClient::connect_uds(master_shard_uds_path.clone()).await {
|
||||
Ok(mut sharded_client) => {
|
||||
// server is running on v3
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.map_err(WebServerError::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||
sharded_client
|
||||
.warmup(
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(WebServerError::Warmup)?,
|
||||
)?;
|
||||
|
||||
let health_ext =
|
||||
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
|
||||
let scheduler = Arc::new(SchedulerV3::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
generation_health,
|
||||
));
|
||||
tracing::info!("Using scheduler V3");
|
||||
|
||||
(scheduler, health_ext, shard_info, max_batch_total_tokens)
|
||||
}
|
||||
Err(_) => {
|
||||
let mut sharded_client = v2::ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
.map_err(WebServerError::Connection)?;
|
||||
|
||||
// server is running on v2
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.map_err(WebServerError::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client.info().await.map_err(WebServerError::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
let max_batch_total_tokens = check_max_batch_total_tokens(
|
||||
sharded_client
|
||||
.warmup(
|
||||
max_input_tokens as u32,
|
||||
max_batch_prefill_tokens,
|
||||
max_total_tokens as u32,
|
||||
max_batch_size,
|
||||
)
|
||||
.await
|
||||
.map_err(WebServerError::Warmup)?,
|
||||
)?;
|
||||
|
||||
let health_ext =
|
||||
HealthCheck::new(Arc::new(sharded_client.clone()), generation_health.clone());
|
||||
let scheduler = Arc::new(SchedulerV2::new(
|
||||
sharded_client,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
shard_info.requires_padding,
|
||||
shard_info.window_size,
|
||||
shard_info.speculate,
|
||||
generation_health,
|
||||
));
|
||||
tracing::info!("Using scheduler V2");
|
||||
|
||||
(scheduler, health_ext, shard_info, max_batch_total_tokens)
|
||||
}
|
||||
}
|
||||
};
|
||||
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
|
||||
|
||||
let validation = Validation::new(
|
||||
validation_workers,
|
||||
tokenizer,
|
||||
|
@ -1662,11 +1857,11 @@ pub async fn run(
|
|||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
grammar_support,
|
||||
disable_grammar_support,
|
||||
);
|
||||
|
||||
let infer = Infer::new(
|
||||
scheduler,
|
||||
backend,
|
||||
validation,
|
||||
max_concurrent_requests,
|
||||
tokenizer_config,
|
||||
|
@ -1703,8 +1898,8 @@ pub async fn run(
|
|||
let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
|
||||
let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
|
||||
// Speculated tokens buckets
|
||||
let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
|
||||
let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
|
||||
// let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
|
||||
// let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();
|
||||
|
||||
// Prometheus handler
|
||||
let builder = PrometheusBuilder::new()
|
||||
|
@ -1717,9 +1912,9 @@ pub async fn run(
|
|||
.set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
|
||||
.unwrap()
|
||||
.set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
|
||||
.unwrap()
|
||||
.set_buckets_for_metric(skipped_matcher, &skipped_buckets)
|
||||
.unwrap();
|
||||
// .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
|
||||
// .unwrap();
|
||||
let prom_handle = builder
|
||||
.install_recorder()
|
||||
.expect("failed to install metrics recorder");
|
||||
|
@ -1735,18 +1930,18 @@ pub async fn run(
|
|||
let info = Info {
|
||||
model_id: model_info.model_id,
|
||||
model_sha: model_info.sha,
|
||||
model_dtype: shard_info.dtype,
|
||||
model_device_type: shard_info.device_type,
|
||||
// model_dtype: shard_info.dtype,
|
||||
// model_device_type: shard_info.device_type,
|
||||
model_pipeline_tag: model_info.pipeline_tag,
|
||||
max_concurrent_requests,
|
||||
max_best_of,
|
||||
max_stop_sequences,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
// waiting_served_ratio,
|
||||
// max_batch_total_tokens,
|
||||
// max_waiting_tokens,
|
||||
// max_batch_size,
|
||||
validation_workers,
|
||||
max_client_batch_size,
|
||||
router: env!("CARGO_PKG_NAME"),
|
||||
|
@ -1907,7 +2102,6 @@ pub async fn run(
|
|||
// add layers after routes
|
||||
app = app
|
||||
.layer(Extension(info))
|
||||
.layer(Extension(health_ext.clone()))
|
||||
.layer(Extension(compat_return_full_text))
|
||||
.layer(Extension(infer))
|
||||
.layer(Extension(compute_type))
|
||||
|
@ -1945,6 +2139,68 @@ pub async fn run(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// get model info from the Huggingface Hub
|
||||
pub async fn get_hub_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
|
||||
let response = api.info_request().send().await.ok()?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let hub_model_info: HubModelInfo =
|
||||
serde_json::from_str(&response.text().await.ok()?).ok()?;
|
||||
if let Some(sha) = &hub_model_info.sha {
|
||||
tracing::info!(
|
||||
"Serving revision {sha} of model {}",
|
||||
hub_model_info.model_id
|
||||
);
|
||||
}
|
||||
Some(hub_model_info)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// get base tokenizer
|
||||
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
|
||||
let config_filename = api_repo.get("config.json").await.ok()?;
|
||||
|
||||
// Open the file in read-only mode with buffer.
|
||||
let file = File::open(config_filename).ok()?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
// Read the JSON contents of the file as an instance of `User`.
|
||||
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
|
||||
|
||||
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
|
||||
let api_base_repo = api.repo(Repo::with_revision(
|
||||
base_model_id.to_string(),
|
||||
RepoType::Model,
|
||||
"main".to_string(),
|
||||
));
|
||||
|
||||
api_base_repo.get("tokenizer.json").await.ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// get tokenizer_config from the Huggingface Hub
|
||||
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
|
||||
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
|
||||
|
||||
// Open the file in read-only mode with buffer.
|
||||
let file = File::open(tokenizer_config_filename).ok()?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
|
||||
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
|
||||
.map_err(|e| {
|
||||
tracing::warn!("Unable to parse tokenizer config: {}", e);
|
||||
e
|
||||
})
|
||||
.ok()?;
|
||||
|
||||
Some(tokenizer_config)
|
||||
}
|
||||
|
||||
/// Shutdown signal handler
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
|
@ -2008,16 +2264,77 @@ impl From<InferError> for Event {
|
|||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WebServerError {
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
Connection(ClientError),
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to get the Python model shards info: {0}")]
|
||||
Info(ClientError),
|
||||
#[error("Unable to warmup the Python model shards: {0}")]
|
||||
Warmup(ClientError),
|
||||
#[error("Not enough memory to handle `max_total_tokens={0}`")]
|
||||
NotEnoughMemory(usize),
|
||||
#[error("Axum error: {0}")]
|
||||
Axum(#[from] axum::BoxError),
|
||||
}
|
||||
|
||||
/// Create a post_processor for the LlamaTokenizer
|
||||
fn create_post_processor(
|
||||
tokenizer: &Tokenizer,
|
||||
tokenizer_config: &HubTokenizerConfig,
|
||||
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
|
||||
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
|
||||
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
|
||||
|
||||
let bos_token = tokenizer_config.bos_token.as_ref();
|
||||
let eos_token = tokenizer_config.eos_token.as_ref();
|
||||
|
||||
if add_bos_token && bos_token.is_none() {
|
||||
panic!("add_bos_token = true but bos_token is None");
|
||||
}
|
||||
|
||||
if add_eos_token && eos_token.is_none() {
|
||||
panic!("add_eos_token = true but eos_token is None");
|
||||
}
|
||||
|
||||
let mut single = Vec::new();
|
||||
let mut pair = Vec::new();
|
||||
let mut special_tokens = Vec::new();
|
||||
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
let bos_token_id = tokenizer
|
||||
.token_to_id(bos.as_str())
|
||||
.expect("Should have found the bos token id");
|
||||
special_tokens.push((bos.as_str(), bos_token_id));
|
||||
single.push(format!("{}:0", bos.as_str()));
|
||||
pair.push(format!("{}:0", bos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
single.push("$A:0".to_string());
|
||||
pair.push("$A:0".to_string());
|
||||
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
let eos_token_id = tokenizer
|
||||
.token_to_id(eos.as_str())
|
||||
.expect("Should have found the eos token id");
|
||||
special_tokens.push((eos.as_str(), eos_token_id));
|
||||
single.push(format!("{}:0", eos.as_str()));
|
||||
pair.push(format!("{}:0", eos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
if add_bos_token {
|
||||
if let Some(bos) = bos_token {
|
||||
pair.push(format!("{}:1", bos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
pair.push("$B:1".to_string());
|
||||
|
||||
if add_eos_token {
|
||||
if let Some(eos) = eos_token {
|
||||
pair.push(format!("{}:1", eos.as_str()));
|
||||
}
|
||||
}
|
||||
|
||||
let post_processor = TemplateProcessing::builder()
|
||||
.try_single(single)?
|
||||
.try_pair(pair)?
|
||||
.special_tokens(special_tokens)
|
||||
.build()?;
|
||||
|
||||
Ok(post_processor)
|
||||
}
|
||||
|
|
|
@ -78,11 +78,11 @@ pub struct Args {
|
|||
max_top_n_tokens: u32,
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
// waiting_served_ratio: f32,
|
||||
// max_batch_prefill_tokens: u32,
|
||||
// max_batch_total_tokens: Option<u32>,
|
||||
// max_waiting_tokens: usize,
|
||||
// max_batch_size: Option<usize>,
|
||||
revision: Option<String>,
|
||||
validation_workers: usize,
|
||||
messages_api_enabled: bool,
|
||||
|
@ -103,11 +103,11 @@ impl Args {
|
|||
max_top_n_tokens: u32,
|
||||
max_input_tokens: usize,
|
||||
max_total_tokens: usize,
|
||||
waiting_served_ratio: f32,
|
||||
max_batch_prefill_tokens: u32,
|
||||
max_batch_total_tokens: Option<u32>,
|
||||
max_waiting_tokens: usize,
|
||||
max_batch_size: Option<usize>,
|
||||
// waiting_served_ratio: f32,
|
||||
// max_batch_prefill_tokens: u32,
|
||||
// max_batch_total_tokens: Option<u32>,
|
||||
// max_waiting_tokens: usize,
|
||||
// max_batch_size: Option<usize>,
|
||||
revision: Option<String>,
|
||||
validation_workers: usize,
|
||||
messages_api_enabled: bool,
|
||||
|
@ -125,11 +125,11 @@ impl Args {
|
|||
max_top_n_tokens,
|
||||
max_input_tokens,
|
||||
max_total_tokens,
|
||||
waiting_served_ratio,
|
||||
max_batch_prefill_tokens,
|
||||
max_batch_total_tokens,
|
||||
max_waiting_tokens,
|
||||
max_batch_size,
|
||||
// waiting_served_ratio,
|
||||
// max_batch_prefill_tokens,
|
||||
// max_batch_total_tokens,
|
||||
// max_waiting_tokens,
|
||||
// max_batch_size,
|
||||
revision,
|
||||
validation_workers,
|
||||
messages_api_enabled,
|
||||
|
|
|
@ -5,13 +5,12 @@ use crate::{
|
|||
GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,
|
||||
};
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
use image::{io::Reader as ImageReader, ImageFormat};
|
||||
use image::{ImageFormat, ImageReader};
|
||||
use jsonschema::{Draft, JSONSchema};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde_json::Value;
|
||||
use std::io::Cursor;
|
||||
use std::iter;
|
||||
use text_generation_client::{Chunk, Image, InputChunk};
|
||||
use thiserror::Error;
|
||||
use tokenizers::tokenizer::Tokenizer;
|
||||
use tokio::sync::mpsc;
|
||||
|
@ -96,7 +95,7 @@ impl Validation {
|
|||
&self,
|
||||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
) -> Result<Option<(tokenizers::Encoding, Vec<InputChunk>)>, ValidationError> {
|
||||
) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some(sender) = &self.sender {
|
||||
// Create response channel
|
||||
|
@ -122,7 +121,7 @@ impl Validation {
|
|||
inputs: String,
|
||||
truncate: Option<usize>,
|
||||
max_new_tokens: Option<u32>,
|
||||
) -> Result<(Vec<InputChunk>, usize, u32), ValidationError> {
|
||||
) -> Result<(Vec<Chunk>, usize, u32), ValidationError> {
|
||||
// If we have a fast tokenizer
|
||||
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
|
||||
// Create response channel
|
||||
|
@ -181,11 +180,7 @@ impl Validation {
|
|||
input_length = input_length.saturating_sub(max_new_tokens as usize);
|
||||
}
|
||||
|
||||
Ok((
|
||||
vec![Chunk::Text(inputs).into()],
|
||||
input_length,
|
||||
max_new_tokens,
|
||||
))
|
||||
Ok((vec![Chunk::Text(inputs)], input_length, max_new_tokens))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -589,7 +584,7 @@ fn prepare_input(
|
|||
tokenizer: &Tokenizer,
|
||||
config: Option<&Config>,
|
||||
preprocessor_config: Option<&HubPreprocessorConfig>,
|
||||
) -> Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError> {
|
||||
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
|
||||
use Config::*;
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
let (tokenizer_query, input_chunks) = match config {
|
||||
|
@ -601,16 +596,16 @@ fn prepare_input(
|
|||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()).into());
|
||||
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }).into());
|
||||
input_chunks.push(Chunk::Image(Image { data, mimetype }));
|
||||
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() {
|
||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()).into());
|
||||
input_chunks.push(Chunk::Text(inputs[start..].to_string()));
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
|
||||
|
@ -618,7 +613,7 @@ fn prepare_input(
|
|||
|
||||
(tokenizer_query, input_chunks)
|
||||
}
|
||||
_ => (inputs.clone(), vec![Chunk::Text(inputs).into()]),
|
||||
_ => (inputs.clone(), vec![Chunk::Text(inputs)]),
|
||||
};
|
||||
|
||||
// Get the number of tokens in the input
|
||||
|
@ -631,18 +626,51 @@ fn prepare_input(
|
|||
|
||||
type TokenizerRequest = (
|
||||
(String, Option<usize>),
|
||||
oneshot::Sender<Result<(tokenizers::Encoding, Vec<InputChunk>), ValidationError>>,
|
||||
oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
|
||||
Span,
|
||||
);
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub struct Image {
|
||||
pub data: Vec<u8>,
|
||||
pub mimetype: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub enum Chunk {
|
||||
Text(String),
|
||||
Image(Image),
|
||||
}
|
||||
|
||||
/// Convert input chunks to a stringly-typed input for backwards
|
||||
/// compat for backends that haven't implemented chunked inputs.
|
||||
pub trait ChunksToString {
|
||||
/// Convert chunks to string.
|
||||
fn chunks_to_string(&self) -> String;
|
||||
}
|
||||
|
||||
impl ChunksToString for Vec<Chunk> {
|
||||
fn chunks_to_string(&self) -> String {
|
||||
let mut output = String::new();
|
||||
self.iter().for_each(|c| match &c {
|
||||
Chunk::Text(text) => output.push_str(text),
|
||||
Chunk::Image(Image { data, mimetype }) => {
|
||||
let encoded = STANDARD.encode(data);
|
||||
output.push_str(&format!("![](data:{};base64,{})", mimetype, encoded))
|
||||
}
|
||||
});
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) enum ValidGrammar {
|
||||
pub enum ValidGrammar {
|
||||
Json(String),
|
||||
Regex(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidParameters {
|
||||
pub struct ValidParameters {
|
||||
/// / exponential scaling output probability distribution
|
||||
pub temperature: f32,
|
||||
/// / restricting to the k highest probability elements
|
||||
|
@ -666,7 +694,7 @@ pub(crate) struct ValidParameters {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidStoppingParameters {
|
||||
pub struct ValidStoppingParameters {
|
||||
/// / Maximum number of generated tokens
|
||||
pub max_new_tokens: u32,
|
||||
/// / Optional stopping sequences
|
||||
|
@ -677,8 +705,8 @@ pub(crate) struct ValidStoppingParameters {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ValidGenerateRequest {
|
||||
pub inputs: Vec<InputChunk>,
|
||||
pub struct ValidGenerateRequest {
|
||||
pub inputs: Vec<Chunk>,
|
||||
pub input_length: u32,
|
||||
pub truncate: u32,
|
||||
pub decoder_input_details: bool,
|
||||
|
@ -750,6 +778,8 @@ pub enum ValidationError {
|
|||
InvalidImageContent(String),
|
||||
#[error("Could not fetch image: {0}")]
|
||||
FailedFetchImage(#[from] reqwest::Error),
|
||||
#[error("{0} modality is not supported")]
|
||||
UnsupportedModality(&'static str),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -167,22 +167,24 @@ def check_openapi(check: bool):
|
|||
else:
|
||||
os.rename(tmp_filename, filename)
|
||||
print("OpenAPI documentation updated.")
|
||||
errors = subprocess.run(
|
||||
p = subprocess.run(
|
||||
[
|
||||
"swagger-cli",
|
||||
"redocly",
|
||||
# allow for trailing whitespace since it's not significant
|
||||
# and the precommit hook will remove it
|
||||
"validate",
|
||||
"lint",
|
||||
filename,
|
||||
],
|
||||
capture_output=True,
|
||||
).stderr.decode("utf-8")
|
||||
)
|
||||
errors = p.stderr.decode("utf-8")
|
||||
# The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where
|
||||
# utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969
|
||||
if not errors.startswith("Swagger schema validation failed."):
|
||||
print(errors)
|
||||
if p.returncode != 0:
|
||||
print(errors)
|
||||
raise Exception(
|
||||
f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}"
|
||||
f"OpenAPI documentation is invalid, `redocly lint {filename}` showed some error:\n {errors}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
|
Loading…
Reference in New Issue