Merge branch 'main' into feature/get-trace-id-from-req-headers

Author: Hyeongchan Kim, 2024-10-25 20:37:36 +09:00 (committed by GitHub)
Commit: 14e8ca5236
139 changed files with 5292 additions and 4308 deletions

Cargo.lock (generated)

@ -2706,9 +2706,9 @@ dependencies = [
[[package]]
name = "opentelemetry"
version = "0.23.0"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b69a91d4893e713e06f724597ad630f1fa76057a5e1026c0ca67054a9032a76"
checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
dependencies = [
"futures-core",
"futures-sink",
@ -2819,19 +2819,17 @@ dependencies = [
[[package]]
name = "opentelemetry_sdk"
version = "0.23.0"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae312d58eaa90a82d2e627fd86e075cf5230b3f11794e2ed74199ebbe572d4fd"
checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
dependencies = [
"async-trait",
"futures-channel",
"futures-executor",
"futures-util",
"glob",
"lazy_static",
"once_cell",
"opentelemetry 0.23.0",
"ordered-float 4.3.0",
"opentelemetry 0.24.0",
"percent-encoding",
"rand",
"thiserror",
@ -4185,16 +4183,17 @@ dependencies = [
"cmake",
"cxx",
"cxx-build",
"hashbrown 0.14.5",
"hf-hub",
"log",
"parking_lot",
"pkg-config",
"text-generation-router",
"thiserror",
"tokenizers 0.19.1",
"tokenizers",
"tokio",
"tokio-stream",
"tracing",
"tracing-opentelemetry 0.24.0",
"tracing-opentelemetry 0.25.0",
"tracing-subscriber",
]
@ -4212,7 +4211,7 @@ dependencies = [
"tabled",
"text-generation-client",
"thiserror",
"tokenizers 0.20.0",
"tokenizers",
"tokio",
"tracing",
"tracing-subscriber",
@ -4292,7 +4291,7 @@ dependencies = [
"serde_json",
"sysinfo",
"thiserror",
"tokenizers 0.20.0",
"tokenizers",
"tokio",
"tokio-stream",
"tower-http",
@ -4341,7 +4340,7 @@ dependencies = [
"slotmap",
"text-generation-router",
"thiserror",
"tokenizers 0.20.0",
"tokenizers",
"tokio",
"tokio-stream",
"tonic 0.10.2",
@ -4392,7 +4391,7 @@ dependencies = [
"slotmap",
"text-generation-router",
"thiserror",
"tokenizers 0.20.0",
"tokenizers",
"tokio",
"tokio-stream",
"tonic 0.10.2",
@ -4514,39 +4513,6 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
"getrandom",
"hf-hub",
"indicatif",
"itertools 0.12.1",
"lazy_static",
"log",
"macro_rules_attribute",
"monostate",
"onig",
"paste",
"rand",
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",
]
[[package]]
name = "tokenizers"
version = "0.20.0"
@ -4933,14 +4899,14 @@ dependencies = [
[[package]]
name = "tracing-opentelemetry"
version = "0.24.0"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f68803492bf28ab40aeccaecc7021096bd256baf7ca77c3d425d89b35a7be4e4"
checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
dependencies = [
"js-sys",
"once_cell",
"opentelemetry 0.23.0",
"opentelemetry_sdk 0.23.0",
"opentelemetry 0.24.0",
"opentelemetry_sdk 0.24.1",
"smallvec",
"tracing",
"tracing-core",


@ -1,23 +0,0 @@
# All the tooling for CUDA
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 AS cuda-builder
WORKDIR /usr/src/tgi/backends/trtllm
RUN apt update && apt install -y cmake git git-lfs gcc g++ ninja-build libopenmpi-dev python3-dev python3-pip wget
COPY . /usr/src/tgi
RUN chmod +x scripts/install_tensorrt.sh && scripts/install_tensorrt.sh
RUN cmake -G Ninja -B build -DTRT_LIB_DIR=/usr/local/tensorrt/lib -DTRT_INCLUDE_DIR=/usr/local/tensorrt/include .
RUN cmake --build build --parallel -t tgi_trtllm_backend_impl
# All the tooling for Rust
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
WORKDIR /usr/src
# Include CUDA-related libraries and tools in the Rust-based image
COPY --from=cuda-builder /usr/local/cuda /usr/local/cuda
COPY --from=cuda-builder /usr/local/tensorrt /usr/local/tensorrt
COPY --from=cuda-builder /usr/src/tgi/backends/trtllm/build /usr/local/tgi/trtllm/build
ENV PATH=/usr/local/cuda/bin:$PATH
ENV LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH
RUN apt update && apt install -y cmake git gcc g++ ninja-build libopenmpi3


@ -327,7 +327,8 @@ ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV ROCM_USE_SKINNY_GEMM=1
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh


@ -218,7 +218,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
FROM ${PLATFORM} AS final
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV CUDA_GRAPHS=0
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]


@ -10,7 +10,7 @@ COPY . .
RUN cargo chef prepare --recipe-path recipe.json
# CUDA dependent dependencies resolver stage
FROM nvidia/cuda:12.5.1-cudnn-devel-ubuntu22.04 AS cuda-builder
FROM nvidia/cuda:12.6.1-cudnn-devel-ubuntu22.04 AS cuda-builder
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
@ -26,6 +26,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
ninja-build \
pkg-config \
python3 \
python3-dev \
python3-setuptools \
tar \
wget
@ -42,7 +43,7 @@ RUN wget "https://download.open-mpi.org/release/open-mpi/v4.1/$OMPI_TARBALL_FILE
mkdir /usr/src/mpi && \
tar -xf "/opt/src/$OMPI_TARBALL_FILENAME" -C /usr/src/mpi --strip-components=1 && \
cd /usr/src/mpi && \
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda && \
./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \
make -j all && \
make install && \
rm -rf "/opt/src/$OMPI_TARBALL_FILENAME"
@ -82,10 +83,16 @@ RUN mkdir $TGI_INSTALL_PREFIX && mkdir "$TGI_INSTALL_PREFIX/include" && mkdir "$
cd backends/trtllm && \
CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX cargo build --release
FROM nvidia/cuda:12.5.1-cudnn-runtime-ubuntu22.04 AS runtime
FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu22.04 AS runtime
RUN apt update && apt install -y python3-minimal python3-dev python3-pip && \
rm -rf /var/lib/{apt,dpkg,cache,log}/ && \
python3 -m pip install transformers tokenizers
WORKDIR /usr/local/tgi/bin
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV LD_LIBRARY_PATH="/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
ENV TOKENIZERS_PARALLELISM=false
ENV OMPI_MCA_plm_rsh_agent=""
COPY --from=mpi-builder /usr/local/mpi /usr/local/mpi
COPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt


@ -98,7 +98,7 @@ curl 127.0.0.1:8080/generate_stream \
You can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain OpenAI Chat Completion API-compatible responses.
```bash
curl localhost:3000/v1/chat/completions \
curl localhost:8080/v1/chat/completions \
-X POST \
-d '{
"model": "tgi",


@ -158,7 +158,8 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
prefix_len: 0,
cache_len: 0,
chunk_len: None,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
@ -217,8 +218,13 @@ impl Client {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,


@ -134,11 +134,12 @@ impl ShardedClient {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@ -245,7 +246,8 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
prefix_len: 0,
cache_len: 0,
chunk_len: None,
adapter_id: None,
};
let batch = Batch {
@ -255,7 +257,7 @@ impl Health for ShardedClient {
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch).await?;
self.clone().prefill(batch, None).await?;
Ok(())
}
}


@ -1,5 +1,17 @@
cmake_minimum_required(VERSION 3.20)
if (NOT DEFINED CMAKE_CXX_COMPILER_LAUNCHER AND CMAKE_BUILD_TYPE STREQUAL "Debug")
find_program(CCACHE_EXECUTABLE "ccache")
if (CCACHE_EXECUTABLE)
message(STATUS "Using ccache")
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_EXECUTABLE}" CACHE PATH "Path to ccache" FORCE)
endif ()
endif ()
if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
endif ()
project(tgi-trtllm-backend VERSION 1.0.0)
set(CMAKE_CXX_STANDARD 20)
@ -14,7 +26,7 @@ set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include"
set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
# We are using nvidia-ml to query device information at runtime and enable some architecture-specific features
find_package(CUDAToolkit 12.5 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
find_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)
#### External dependencies ####
include(cmake/fmt.cmake)


@ -10,16 +10,17 @@ async-trait = "0.1"
async-stream = "0.3"
clap = { version = "4.5", features = ["derive"] }
cxx = "1.0"
hashbrown = "0.14"
hf-hub = { workspace = true }
log = { version = "0.4", features = [] }
text-generation-router = { path = "../../router" }
tokenizers = { version = "0.19", features = ["hf-hub"] }
tokio = { version = "1.38", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokenizers = { workspace = true }
tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.15"
thiserror = "1.0.62"
thiserror = "1.0.63"
tracing = "0.1"
tracing-opentelemetry = "0.24"
tracing-opentelemetry = "0.25"
tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
parking_lot = "0.12"
[build-dependencies]
cmake = "0.1"


@ -6,7 +6,7 @@ use std::path::{absolute, PathBuf};
const ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 2] = ["spdlog", "fmt"];
const CUDA_ARCH_LIST: Option<&str> = option_env!("CUDA_ARCH_LIST");
const CUDA_REQUIRED_VERSION: &str = "12.5";
const CUDA_REQUIRED_VERSION: &str = "12.6";
const MPI_REQUIRED_VERSION: &str = "4.1";
const INSTALL_PREFIX: Option<&str> = option_env!("CMAKE_INSTALL_PREFIX");
const TENSORRT_ROOT_DIR: Option<&str> = option_env!("TENSORRT_ROOT_DIR");
@ -36,7 +36,7 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
// Build the backend implementation through CMake
let install_path = INSTALL_PREFIX.unwrap_or("/usr/local/tgi");
let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or("/usr/local/tensorrt");
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("90-real"); // Hopper by default
let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or("75-real;80-real;86-real;89-real;90-real");
let mut install_path = PathBuf::from(install_path);
if !install_path.is_absolute() {
@ -81,7 +81,12 @@ fn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf
(PathBuf::from(install_path), deps_folder)
}
fn build_ffi_layer(deps_folder: &PathBuf) {
fn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {
let ndebug = match is_debug {
true => "1",
false => "0",
};
CFG.include_prefix = "backends/trtllm";
cxx_build::bridge("src/lib.rs")
.static_flag(true)
@ -93,9 +98,14 @@ fn build_ffi_layer(deps_folder: &PathBuf) {
.include("/usr/local/tensorrt/include")
.file("src/ffi.cpp")
.std("c++20")
.define("NDEBUG", ndebug)
.compile("tgi_trtllm_backend");
println!("cargo:rerun-if-changed=CMakeLists.txt");
println!("cargo:rerun-if-changed=cmake/trtllm.cmake");
println!("cargo:rerun-if-changed=cmake/json.cmake");
println!("cargo:rerun-if-changed=cmake/fmt.cmake");
println!("cargo:rerun-if-changed=cmake/spdlog.cmake");
println!("cargo:rerun-if-changed=include/backend.h");
println!("cargo:rerun-if-changed=lib/backend.cpp");
println!("cargo:rerun-if-changed=include/ffi.h");
@ -115,7 +125,7 @@ fn main() {
let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);
// Build the FFI layer calling the backend above
build_ffi_layer(&deps_folder);
build_ffi_layer(&deps_folder, is_debug);
// Emit linkage search path
probe!("ompi", MPI_REQUIRED_VERSION);


@ -1,6 +1,6 @@
FetchContent_Declare(
fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt
GIT_TAG 11.0.1
DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/fmtlib/fmt/archive/refs/tags/11.0.2.tar.gz
)
FetchContent_MakeAvailable(fmt)


@ -1,5 +1,6 @@
fetchcontent_declare(
json
DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz
)
fetchcontent_makeavailable(json)


@ -11,7 +11,7 @@ endif ()
fetchcontent_declare(
spdlog
GIT_REPOSITORY https://github.com/gabime/spdlog.git
GIT_TAG v1.14.1
DOWNLOAD_EXTRACT_TIMESTAMP
URL https://github.com/gabime/spdlog/archive/refs/tags/v1.14.1.tar.gz
)
fetchcontent_makeavailable(spdlog)


@ -23,8 +23,9 @@ endif ()
fetchcontent_declare(
trtllm
GIT_REPOSITORY https://github.com/NVIDIA/TensorRT-LLM.git
GIT_TAG a681853d3803ee5893307e812530b5e7004bb6e1
GIT_TAG 201135e58aa525af7e523d091d4c9584229524bc
GIT_SHALLOW FALSE
DOWNLOAD_EXTRACT_TIMESTAMP
)
fetchcontent_makeavailable(trtllm)


@ -5,6 +5,7 @@
#ifndef TGI_TRTLLM_BACKEND_H
#define TGI_TRTLLM_BACKEND_H
#include <array>
#include <cmath>
#include <filesystem>
#include <span>
@ -19,16 +20,33 @@
using json = nlohmann::json;
namespace tle = tensorrt_llm::executor;
#define CAST_SIZETYPE(x) static_cast<tle::SizeType32>(x)
namespace huggingface::tgi::backends {
using RequestId = tle::IdType;
using TokenId = tle::TokenIdType;
const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
constexpr auto FMT_NOT_ENOUGH_GPUS = FMT_STRING(
"Not enough GPUs to allocate requested model (detected: {:d}, required: {:d})");
constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
"Submitting inference [{}] to the executor ({:d} already in-flight)");
constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
"Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");
/**
* Initialize all the components required by TRTLLM.
* It is required to call this function before attempting to load any engine
*/
void InitializeBackend();
/**
* Initialize logging mechanism
*/
void InitializeLogging();
/**
*
* @param config TensorRT-LLM configuration object
@ -37,6 +55,14 @@ namespace huggingface::tgi::backends {
*/
tle::ExecutorConfig GetExecutorConfig(const json &config, const std::string &workerPath);
/**
*
* @param worldSize
* @param workerPath
* @return
*/
tle::ParallelConfig GetParallelConfig(size_t worldSize, std::string workerPath) noexcept;
/**
* Get the sampling configuration from the parameters provided by TGI
* @param topK
@ -54,7 +80,15 @@ namespace huggingface::tgi::backends {
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed
);
) noexcept;
/**
* Attempt to retrieve the stop words from the model's generation_config.json, if present
* @param generationConfigPath
* @return
*/
std::optional<std::list<std::vector<TokenId>>>
GetStopWordsFromConfig(const std::filesystem::path &generationConfigPath) noexcept;
/**
*
@ -64,18 +98,16 @@ namespace huggingface::tgi::backends {
const json config;
tle::Executor executor;
/** Frequently accessed variables cached here **/
uint32_t maxNumTokens;
std::list<std::vector<TokenId>> stopWords;
public:
explicit TensorRtLlmBackend(
const std::filesystem::path &engineFolder,
const std::filesystem::path &executorWorker
);
/**
* Indicate if the backend is ready to accept incoming request
* @return true if ready, false otherwise
*/
[[nodiscard]] bool IsReady() const;
/**
* Query the executor for the number of tokens available for pulling
* @return
@ -88,32 +120,23 @@ namespace huggingface::tgi::backends {
* @param topK
* @param topP
* @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param repetitionPenalty
* @param frequencyPenalty
* @param seed
* @return Request id related to this generation for reference
*/
[[nodiscard]] RequestId Submit(
const std::vector<TokenId> &tokens,
uint32_t maxNewTokens,
int32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
float_t repetitionPenalty,
float_t frequencyPenalty,
uint64_t seed
);
/**
*
* @param requestId The request id to poll the generation results
* @return
*/
std::vector<tle::Response> Poll(RequestId requestId);
/**
* Stop the underlying executor
*/
void Shutdown();
[[nodiscard]] std::vector<tle::Response> PullNewTokens();
};
}


@ -5,20 +5,31 @@
#ifndef TGI_TRTLLM_BACKEND_FFI_H
#define TGI_TRTLLM_BACKEND_FFI_H
#include <cmath>
#include <cstddef>
#include <memory>
#include "backend.h"
namespace huggingface::tgi::backends {
class TensorRtLlmBackendImpl;
}
// Template to support returning error from TllmException back to Rust in a Result<>
#include <tensorrt_llm/common/tllmException.h>
namespace rust::behavior {
template<typename Try, typename Fail>
static void trycatch(Try &&func, Fail &&fail) noexcept try {
func();
} catch (tensorrt_llm::common::TllmException &e) {
fail(e.what());
}
}
#include "backends/trtllm/src/lib.rs.h"
namespace huggingface::tgi::backends {
// struct GenerationContext;
class TensorRtLlmBackendImpl : public TensorRtLlmBackend {
public:
/***
@ -28,15 +39,10 @@ namespace huggingface::tgi::backends {
*/
TensorRtLlmBackendImpl(const std::string_view &engineFolder, const std::string_view &executorWorker);
/***
*
* @return
*/
bool IsReady() const;
/***
*
* @param tokens
* @param maxNewTokens
* @param topK
* @param topP
* @param temperature
@ -47,21 +53,15 @@ namespace huggingface::tgi::backends {
*/
[[nodiscard("returned request id should be used to refer to the request's generation result later on")]]
uint64_t
Submit(rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature,
Submit(rust::Slice<const uint32_t> tokens, uint32_t maxNewTokens,
int32_t topK, float_t topP, float_t temperature,
float_t repetition_penalty, float_t frequency_penalty, uint64_t seed);
/***
*
* @param requestId
* @param ctx
* @param callback
* @return
*/
size_t StreamTokens(
const RequestId requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback);
std::unique_ptr<std::vector<GenerationStep>> PullTokens();
};
/***


@ -14,7 +14,7 @@
namespace huggingface::hardware::cuda {
#define AMPERE_SM_MAJOR 8
#define HOPPER_SM_MAJOR 8
#define HOPPER_SM_MAJOR 9
/**
* Store information about the version of the CUDA Compute Capabilities detected on the device
@ -23,9 +23,9 @@ namespace huggingface::hardware::cuda {
int32_t major;
int32_t minor;
[[nodiscard]] constexpr bool isPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
[[nodiscard]] constexpr bool IsPostAmpere() const { return major >= AMPERE_SM_MAJOR; }
[[nodiscard]] constexpr bool isPostHopper() const { return major >= HOPPER_SM_MAJOR; }
[[nodiscard]] constexpr bool IsPostHopper() const { return major >= HOPPER_SM_MAJOR; }
};
CudaComputeCapabilities GetCudaComputeCapabilities() {


@ -1,3 +1,4 @@
#include <cstdlib>
#include <fstream>
#include <fmt/ranges.h>
@ -7,11 +8,33 @@
#include "backend.h"
#include "hardware.h"
void huggingface::tgi::backends::InitializeLogging() {
#ifdef NDEBUG
if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv("TRTLLM_LOG_LEVEL")) {
std::string log_level(TRTLLM_LOG_LEVEL_CSTR);
std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {
return std::tolower(c);
});
if (log_level == "debug")
spdlog::set_level(spdlog::level::debug);
else
spdlog::set_level(spdlog::level::info);
}
#else
spdlog::set_level(spdlog::level::debug);
#endif
}
void huggingface::tgi::backends::InitializeBackend() {
SPDLOG_INFO("Initializing Backend...");
nvmlInit_v2();
initTrtLlmPlugins();
InitializeLogging();
SPDLOG_INFO("Backend Executor Version: {}", tle::version());
const auto numGpus = huggingface::hardware::cuda::GetNumDevices();
if (numGpus.has_value()) {
SPDLOG_INFO("Detected {:d} Nvidia GPU(s)", numGpus.value());
@ -20,47 +43,49 @@ void huggingface::tgi::backends::InitializeBackend() {
}
}
[[nodiscard]]
tle::ParallelConfig
huggingface::tgi::backends::GetParallelConfig(const size_t worldSize, const std::string workerPath) noexcept {
auto mode = tle::CommunicationMode::kLEADER;
std::optional<tle::OrchestratorConfig> orchestratorConfig = std::nullopt;
if (worldSize > 1) {
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
mode = tle::CommunicationMode::kORCHESTRATOR;
orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, workerPath, nullptr, true);
} else {
SPDLOG_INFO("Detected single engine deployment, using leader mode");
}
return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);
}
[[nodiscard]]
tle::ExecutorConfig huggingface::tgi::backends::GetExecutorConfig(const json &config, const std::string &workerPath) {
tle::ExecutorConfig execConfig(1);
tle::ExecutorConfig execConfig(/* maxBeamWidth = */ 1);
// Retrieve the compute capabilities to enable some options at runtime
const auto computeCapabilities = huggingface::hardware::cuda::GetCudaComputeCapabilities();
// Single engine (TP = PP = 1) -> using leader mode (no MPI involved)
if (config["/pretrained_config/mapping/world_size"_json_pointer].get<uint8_t>() == 1) {
SPDLOG_INFO("Detected single engine deployment, using leader mode");
execConfig.setParallelConfig(tle::ParallelConfig(
tle::CommunicationType::kMPI,
tle::CommunicationMode::kLEADER,
std::nullopt,
std::nullopt,
std::nullopt
));
} else { // Multiple engines -> using orchestrator mode (MPI involved)
SPDLOG_INFO("Detected sharded engine deployment, using orchestrator mode");
execConfig.setParallelConfig(tle::ParallelConfig(
tle::CommunicationType::kMPI,
tle::CommunicationMode::kORCHESTRATOR,
std::nullopt,
std::nullopt,
tle::OrchestratorConfig(true, workerPath, nullptr, true)
));
}
const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
execConfig.setParallelConfig(GetParallelConfig(worldSize, workerPath));
// Define some configuration variables
execConfig.setKvCacheConfig(tle::KvCacheConfig(true));
execConfig.setEnableChunkedContext(computeCapabilities.isPostAmpere());
execConfig.setEnableChunkedContext(computeCapabilities.IsPostAmpere());
execConfig.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
return execConfig;
}
tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
uint32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed) {
const uint32_t topK,
const float_t topP,
const float_t temperature,
const float_t repetition_penalty,
const float_t frequency_penalty,
const uint64_t seed) noexcept {
return tle::SamplingConfig(
1, // TGI only use a single beam
topK,
@ -78,69 +103,101 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
);
}
std::optional<std::list<std::vector<huggingface::tgi::backends::TokenId>>>
huggingface::tgi::backends::GetStopWordsFromConfig(
const std::filesystem::path &generationConfigPath) noexcept {
if (exists(generationConfigPath)) {
const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
if (const auto eosTokenIds = generationConfig["/eos_token_id"_json_pointer]; eosTokenIds.is_array()) {
SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
std::list<std::vector<huggingface::tgi::backends::TokenId>> stopWords(eosTokenIds.size());
const auto to_single_token = [](const auto tokenIdObj) -> decltype(stopWords)::value_type {
return {tokenIdObj.template get<tle::TokenIdType>()};
};
std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(), to_single_token);
return stopWords;
} else {
SPDLOG_INFO("Invalid EOS tokens entry found (not an array)");
}
} else {
SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
}
return std::nullopt;
}
huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
const std::filesystem::path &enginesFolder,
const std::filesystem::path &executorWorker
) :
config(json::parse(std::ifstream(enginesFolder / "config.json"))),
executor(
enginesFolder,
tensorrt_llm::executor::ModelType::kDECODER_ONLY,
GetExecutorConfig(config, executorWorker.string()
)) {
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
}
executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
GetExecutorConfig(config, executorWorker.string())) {
bool huggingface::tgi::backends::TensorRtLlmBackend::IsReady() const {
return executor.canEnqueueRequests();
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get<std::string_view>());
// Ensure we have enough GPUs on the system
const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
if (numGpus < worldSize) {
SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
// todo : raise exception to catch on rust side
}
// Cache variables
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
// Attempt to discover stopWords from the generation_config.json
const auto generationConfigPath = enginesFolder / "generation_config.json";
stopWords = GetStopWordsFromConfig(generationConfigPath).value_or(std::list<std::vector<TokenId>>());
}
[[nodiscard("Returned number of requests needs to be consumed")]]
size_t huggingface::tgi::backends::TensorRtLlmBackend::NumResponsesReady() const {
#ifdef NDEBUG
return executor.getNumResponsesReady();
#else
const auto numResponses = executor.getNumResponsesReady();
if (numResponses > 0) SPDLOG_INFO(FMT_STRING("Num responses ready: {:d}"), numResponses);
return numResponses;
#endif
}
[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const std::vector<tle::TokenIdType> &tokens,
const uint32_t maxNewTokens,
const int32_t topK,
const float_t topP,
const float_t temperature,
const float_t repetition_penalty,
const float_t frequency_penalty,
const float_t repetitionPenalty,
const float_t frequencyPenalty,
const uint64_t seed
) {
#ifdef NDEBUG
SPDLOG_DEBUG(
FMT_STRING("Submitting inference over {:d} tokens to the executor ({:d} already in-flight)"),
tokens.size(),
executor.getLatestIterationStats().back().numActiveRequests
);
#else
SPDLOG_DEBUG(
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
fmt::join(tokens, ", "),
executor.getLatestIterationStats().front().numActiveRequests
);
const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
#ifndef NDEBUG
{
const auto &iterations = executor.getLatestIterationStats();
const auto &lastIteration = iterations.front();
SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
}
#endif
const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<size_t>();
const auto maxNewTokens = static_cast<int32_t>(std::max(1ul, maxNumTokens - tokens.size()));
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
const auto output = tle::OutputConfig(true, false, false, true, false);
return executor.enqueueRequest(
tle::Request{tokens, maxNewTokens, true, sampling, output});
// Build the request
auto request = tle::Request{tokens, CAST_SIZETYPE(maxNewTokensChecked), true, sampling, OUTPUT_CONFIG};
request.setStopWords(stopWords);
// Submit to the executor for batching
return executor.enqueueRequest(request);
}
[[nodiscard("Generated tokens result must be used")]]
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::Poll(const tle::IdType requestId) {
SPDLOG_DEBUG(FMT_STRING("Polling status for request {:d}"), requestId);
return executor.awaitResponses(requestId);
}
void huggingface::tgi::backends::TensorRtLlmBackend::Shutdown() {
SPDLOG_INFO("Shutting down executor");
executor.shutdown();
std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
return executor.awaitResponses();
}


@ -2,12 +2,13 @@
set -ex
TRT_VER="10.2.0.19"
CUDA_VER="12.5"
CUDNN_VER="9.2.1.18-1"
NCCL_VER="2.22.3-1+cuda12.5"
CUBLAS_VER="12.5.3.2-1"
NVRTC_VER="12.5.82-1"
TRT_VER_BASE="10.4.0"
TRT_VER_FULL="${TRT_VER_BASE}.26"
CUDA_VER="12.6"
CUDNN_VER="9.5.0.50-1"
NCCL_VER="2.22.3-1+cuda12.6"
CUBLAS_VER="12.6.3.3-1"
NVRTC_VER="12.6.77-1"
for i in "$@"; do
case $i in
@ -32,8 +33,9 @@ install_ubuntu_requirements() {
ARCH=$(uname -m)
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "aarch64" ];then ARCH="sbsa";fi
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/${ARCH}/cuda-keyring_1.0-1_all.deb
dpkg -i cuda-keyring_1.0-1_all.deb
curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
rm /etc/apt/sources.list.d/cuda-ubuntu2404-x86_64.list
apt-get update
if [[ $(apt list --installed | grep libcudnn9) ]]; then
@ -71,7 +73,7 @@ install_centos_requirements() {
install_tensorrt() {
#PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
#PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
TRT_CUDA_VERSION="12.5"
TRT_CUDA_VERSION="12.6"
if [ -z "$RELEASE_URL_TRT" ];then
ARCH=${TRT_TARGETARCH}
@ -79,12 +81,12 @@ install_tensorrt() {
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-24.04" && OS="ubuntu-24.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_BASE}/tars/TensorRT-${TRT_VER_FULL}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
fi
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
tar -xf /tmp/TensorRT.tar -C /usr/local/
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
mv /usr/local/TensorRT-${TRT_VER_FULL} /usr/local/tensorrt
# pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
rm -rf /tmp/TensorRT.tar
}


@ -1,330 +0,0 @@
use std::future::Future;
use std::path::Path;
use std::pin::{pin, Pin};
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use async_trait::async_trait;
use cxx::UniquePtr;
use log::{error, warn};
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::time::{sleep, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::{Stream, StreamExt};
use tracing::{instrument, span, Level};
// use tokio::sync::RwLock;
use parking_lot::RwLock;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::UnsupportedModality;
use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidationError};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
// Value used to poll the state of the generation stream
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
type InferResult<T> = Result<T, InferError>;
pub(crate) struct Generation {
executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
done: Arc<AtomicBool>,
}
/// Holds the user-provided input to be executed along with a channel allowing
/// all the generated tokens for that request to be streamed back to the end user.
pub struct GenerationContext {
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokenizer: Arc<Tokenizer>,
tokens: Vec<u32>,
done: Arc<AtomicBool>,
queued: Instant,
start: Option<Instant>,
}
impl Stream for Generation {
type Item = usize;
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let interval = POLLING_INTERVAL_US.get_or_init(|| {
u64::from_str(option_env!("TRTLLM_BACKEND_POLLING_INTERVAL_US").unwrap_or("100"))
.expect("Invalid value provided for envvar POLLING_INTERVAL_US")
});
if !self.done.load(Ordering::Relaxed) {
let backend = pin!(self.executor.read());
let status = match backend.poll(ctx) {
Poll::Ready(executor_r) => {
let ready = executor_r.num_responses_ready();
if ready == 0 {
Poll::Pending
} else {
Poll::Ready(Some(ready))
}
}
Poll::Pending => Poll::Pending,
};
let waker = ctx.waker().clone();
tokio::spawn(async {
sleep(Duration::from_micros(*interval)).await;
waker.wake();
});
status
} else {
Poll::Ready(None) // end of stream
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(1, None)
}
}
unsafe impl Send for TensorRtLlmBackendImpl {}
unsafe impl Sync for TensorRtLlmBackendImpl {}
/// Implements the logic to execute generation with TensorRT-LLM executor API in background
pub struct TensorRtLlmBackend {
tokenizer: Arc<Tokenizer>,
// Backing the backend behind a RwLock to allow concurrent read access to retrieve
// the number of available tokens (read only) in the Generation stream
backend: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
}
impl TensorRtLlmBackend {
pub fn new<P: AsRef<Path> + Send + 'static, PP: AsRef<Path> + Send + 'static>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
) -> Result<Self, TensorRtLlmBackendError> {
Ok(TensorRtLlmBackend {
tokenizer: Arc::new(tokenizer),
backend: Arc::new(RwLock::new(create_tensorrt_llm_backend(
engine_folder.as_ref().to_str().unwrap(),
executor_worker_path.as_ref().to_str().unwrap(),
))),
})
}
fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
if request.top_n_tokens > 1 {
return Err(InferError::ValidationError(
ValidationError::TopNTokensDisabled,
));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(InferError::ValidationError(ValidationError::Grammar));
}
match request.inputs.len() {
0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
2.. => Err(InferError::GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(text) => Ok(text),
Chunk::Image(_) => Err(InferError::ValidationError(UnsupportedModality("image"))),
},
}
}
fn generate(
&self,
sender: UnboundedSender<InferResult<InferStreamResponse>>,
tokens: Vec<u32>,
top_k: u32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) {
let tokenizer = Arc::clone(&self.tokenizer);
let executor = Arc::clone(&self.backend);
// Let's push this in async context
tokio::spawn(async move {
// Define the generation state
let mut generation = Generation {
executor: executor.clone(),
done: Arc::new(AtomicBool::new(false)),
};
// Define the context over the generation
// TODO(asap): Do we really need so many shared-ownership?
let ctx = Box::new(GenerationContext {
sender: sender.clone(),
tokenizer,
tokens: vec![],
done: Arc::clone(&generation.done),
start: None,
queued: Instant::now(),
});
// We are leaking the context on purpose to avoid the box being dropped while there is
// still computation ongoing
// TODO(asap): Can we achieve the same with an Arc<Box<T>> without the need to go unsafe?
let ctx_ = Box::leak(ctx);
// Submit the request to the batcher
let request_id = span!(Level::DEBUG, "submit")
.in_scope(|| async {
let mut handle = executor.write().await;
let request_id = handle.pin_mut().submit(
&tokens,
top_k as i32,
top_p,
temperature,
repetition_penalty,
frequency_penalty,
seed,
);
request_id
})
.await;
while let Some(_) = generation.next().await {
let mut executor_w = executor.write().await;
let executor = executor_w.pin_mut();
span!(Level::DEBUG, "decode")
.in_scope(|| async {
unsafe {
executor.stream_tokens(
request_id,
ctx_,
|ctx: *mut GenerationContext, step: GenerationStep| {
let inner_ctx = &mut *ctx;
// Update the timestamp at which the request started effectively
// Can be a bit off, would need to be before the callback, let's see
inner_ctx.start.get_or_insert(Instant::now());
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
// Ensure we are not running into errors
let parcel = if !step.has_error {
// Insert the latest generated token to the tracker
inner_ctx.tokens.push(step.token_id);
// Decode the token
let text = inner_ctx
.tokenizer
.decode(&[step.token_id], true)
.expect("Failed to decode token");
let special = inner_ctx
.tokenizer
.get_added_vocabulary()
.is_special_token(&text);
// Create the structure holding the token
let token = Token {
id: step.token_id,
text,
logprob: step.log_prob,
special,
};
if step.is_final {
let generated_text = inner_ctx
.tokenizer
.decode(&inner_ctx.tokens, true)
.expect("Failed to decode generated_tokens");
Ok(InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text: GeneratedText {
text: generated_text,
generated_tokens: inner_ctx.tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
},
start: inner_ctx.start.unwrap_or(Instant::now()),
queued: inner_ctx.queued,
})
} else {
Ok(InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
})
}
} else {
error!("Error caught while decoding: {}", &step.error_msg);
Err(InferError::GenerationError(step.error_msg))
};
// Send the parcel to the client
inner_ctx
.sender
.send(parcel)
.expect("Failed to sent msg through the channel");
},
);
}
})
.await;
}
// "Properly" free the shared context...
// TODO: clean that piece of sh** asap
unsafe {
let _ = Box::from_raw(ctx_);
}
});
}
}
#[async_trait]
impl Backend for TensorRtLlmBackend {
#[instrument(skip_all)]
fn schedule(
&self,
request: ValidGenerateRequest,
) -> InferResult<UnboundedReceiverStream<InferResult<InferStreamResponse>>> {
// Let's add a few more validations
let input = TensorRtLlmBackend::validate(&request)?;
// Channel to stream the generated tokens as they come from the worker thread back to the transport layer
let (sender, receiver) = unbounded_channel();
// Unpack parameters
let params = &request.parameters;
// Preprocess the inputs to send to TRTLLM backend
let encoding = self
.tokenizer
.encode(input.as_str(), true)
.map_err(|e| InferError::GenerationError(e.to_string()))?;
// Generate the response
self.generate(
sender,
Vec::from(encoding.get_ids()),
params.top_k,
params.top_p,
params.temperature,
params.repetition_penalty,
params.frequency_penalty,
params.seed,
);
Ok(UnboundedReceiverStream::new(receiver))
}
async fn health(&self, _current_health: bool) -> bool {
true
}
}


@ -1,9 +1,16 @@
use std::path::PathBuf;
use thiserror::Error;
use text_generation_router::server;
#[derive(Debug, Error)]
pub enum TensorRtLlmBackendError {
#[error("Provided engine folder {0} doesn't exist")]
EngineFolderDoesntExists(PathBuf),
#[error("Provided executorWorker binary path {0} doesn't exist")]
ExecutorWorkerNotFound(PathBuf),
#[error("TensorRT-LLM Runtime error: {0}")]
Runtime(String),
#[error("Tokenizer error: {0}")]
Tokenizer(String),
#[error("Argument validation error: {0}")]


@ -3,11 +3,13 @@
//
#pragma once
#include <cmath>
#include <algorithm>
#include <exception>
#include <filesystem>
#include <functional>
#include <limits>
#include <iterator>
#include <ranges>
#include <vector>
#include <spdlog/spdlog.h>
@ -20,61 +22,64 @@ huggingface::tgi::backends::TensorRtLlmBackendImpl::TensorRtLlmBackendImpl(
) : TensorRtLlmBackend(engineFolder, executorWorker) {}
bool huggingface::tgi::backends::TensorRtLlmBackendImpl::IsReady() const {
return TensorRtLlmBackend::IsReady();
}
uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
rust::Slice<const uint32_t> tokens, int32_t topK, float_t topP, float_t temperature, float_t repetition_penalty,
float_t frequency_penalty, uint64_t seed) {
rust::Slice<const uint32_t> tokens,
uint32_t maxNewTokens,
int32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed) {
// This will copy all the items from the initial slice
std::vector<int32_t> tokens_(std::make_move_iterator(tokens.begin()), std::make_move_iterator(tokens.end()));
std::vector<int32_t> tokens_(tokens.begin(), tokens.end());
return TensorRtLlmBackend::Submit(
std::move(tokens_), topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
std::move(tokens_), maxNewTokens, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
}
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
const uint64_t requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback) {
std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
const auto responses = TensorRtLlmBackend::PullNewTokens();
size_t numTokens = 0;
for (const auto &item: Poll(requestId)) {
GenerationStep step;
if (!item.hasError()) {
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
const auto decoded = item.getResult();
auto steps = std::make_unique<std::vector<GenerationStep>>();
steps->reserve(responses.size());
const auto token = decoded.outputTokenIds[0][0];
const auto isFinal = decoded.isFinal;
const auto logProb = decoded.logProbs.value()[0][0];
#ifndef NDEBUG
SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses->size());
#endif
++numTokens;
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
step = huggingface::tgi::backends::GenerationStep{
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
// Transform tle::Response to GenerationStep
std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [](const tle::Response &r) {
const auto reqId = r.getRequestId();
if (!r.hasError()) {
const auto result = r.getResult();
return GenerationStep{
reqId,
static_cast<uint32_t>(result.outputTokenIds[0][0]),
result.logProbs.value()[0][0],
result.isFinal,
false,
std::string()
};
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
} else {
// TODO : Return rest::Result with error
const auto what = item.getErrorMsg();
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
step = huggingface::tgi::backends::GenerationStep{
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
return GenerationStep{
reqId,
0,
0.0,
true,
true,
std::move(r.getErrorMsg())
};
}
});
callback(std::move(ctx), std::move(step));
}
return numTokens;
return steps;
}
std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
huggingface::tgi::backends::CreateTensorRtLlmBackend(rust::Str engineFolder, rust::Str executorWorker) {
SPDLOG_INFO("Creating TensorRT-LLM Backend");
// Unconditionally call this to initialize and discover TRTLLM plugins
InitializeBackend();


@ -1,14 +1,16 @@
pub use backend::{GenerationContext, TensorRtLlmBackend};
pub use looper::TensorRtLlmBackendV2;
mod backend;
pub mod errors;
mod looper;
mod utils;
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
/// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration
#[derive(Debug, Clone)]
pub struct GenerationStep {
request_id: u64,
token_id: u32,
log_prob: f32,
is_final: bool,
@ -16,10 +18,6 @@ mod ffi {
error_msg: String,
}
extern "Rust" {
type GenerationContext;
}
unsafe extern "C++" {
include!("backends/trtllm/src/ffi.cpp");
@ -44,10 +42,7 @@ mod ffi {
fn CreateTensorRtLlmBackend(
engine_folder: &str,
executor_worker: &str,
) -> UniquePtr<TensorRtLlmBackendImpl>;
// #[rust_name = "is_ready"]
// fn IsReady(self: &TensorRtLlmBackendImpl) -> bool;
) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;
#[rust_name = "num_responses_ready"]
fn NumResponsesReady(self: &TensorRtLlmBackendImpl) -> usize;
@ -56,23 +51,18 @@ mod ffi {
fn Submit(
self: Pin<&mut TensorRtLlmBackendImpl>,
tokens: &[u32],
max_new_tokens: u32,
top_k: i32,
top_p: f32,
temperature: f32,
repetition_penalty: f32,
frequency_penalty: f32,
seed: u64,
) -> u64;
) -> Result<u64>;
#[rust_name = "stream_tokens"]
unsafe fn StreamTokens(
#[rust_name = "pull_tokens"]
fn PullTokens(
self: Pin<&mut TensorRtLlmBackendImpl>,
request_id: u64,
ctx: *mut GenerationContext,
cb: unsafe fn(*mut GenerationContext, GenerationStep),
) -> usize;
// #[rust_name = "shutdown"]
// fn Shutdown(self: Pin<&mut TensorRtLlmBackendImpl>);
) -> Result<UniquePtr<CxxVector<GenerationStep>>>;
}
}


@ -0,0 +1,382 @@
use std::hint;
use std::ops::Deref;
use std::path::Path;
use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::TryAcquireError;
use tokio::task::{spawn_blocking, JoinHandle};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
use text_generation_router::infer::InferError::{GenerationError, ValidationError};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidationError::{
EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
};
use text_generation_router::validation::{Chunk, ValidGenerateRequest};
use text_generation_router::{FinishReason, Token};
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
use crate::utils::first_line;
type InferResult<T> = Result<T, InferError>;
/// Wraps the request along with the channel used to stream the decoded tokens back to the client
struct GenerationContext {
request: ValidGenerateRequest,
start: Option<Instant>,
queued: Instant,
streamer: UnboundedSender<InferResult<InferStreamResponse>>,
}
#[derive(Debug, Copy, Clone)]
struct DecodedToken {
id: u32,
log_prob: f32,
is_final: bool,
}
impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
type Error = InferError;
fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
if !step.has_error {
Ok(Self {
id: step.token_id,
log_prob: step.log_prob,
is_final: step.is_final,
})
} else {
Err(GenerationError(step.error_msg.clone()))
}
}
}
/// Wraps the decoded token with the channel used to stream it back to the client
struct DecodedTokenContext {
token: DecodedToken,
start: Option<Instant>,
queued: Instant,
channel: UnboundedSender<InferResult<InferStreamResponse>>,
}
fn executor_status_looper(
mut backend: UniquePtr<TensorRtLlmBackendImpl>,
max_inflight_requests: usize,
mut waiting_requests: UnboundedReceiver<GenerationContext>,
post_processor_sender: UnboundedSender<(u64, InferResult<DecodedTokenContext>)>,
) {
// Track the tuple (request_id, stream) for each request
let mut in_flights =
HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);
// TODO: Does it need a spin-loop?
'scheduler: loop {
// Is there any request pending to be scheduled?
let awaiting_requests = waiting_requests.len();
for _ in 0..awaiting_requests {
// Retrieve all the requests
if let Some(mut ctx) = waiting_requests.blocking_recv() {
// Submit each request to the executor and move its context to the in-flight tracker
let request = &ctx.request;
let generation_params = &request.parameters;
let stopping_params = &request.stopping_parameters;
let input_ids = request.input_ids.as_deref();
// Submit to the TensorRT-LLM executor for scheduling
match backend.pin_mut().submit(
&input_ids.unwrap(), // This is checked beforehand in validate()
stopping_params.max_new_tokens,
generation_params.top_k as i32,
generation_params.top_p,
generation_params.temperature,
generation_params.repetition_penalty,
generation_params.frequency_penalty,
generation_params.seed,
) {
Ok(request_id) => {
// Insert the context linked to the generated request id in the tracker
debug!("[in-flight] Added {}", request_id);
ctx.start = Some(Instant::now());
in_flights.insert(request_id, ctx);
}
Err(e) => {
// Return to the caller
let what = e.to_string();
error!(error = what.as_str(), "Failed to schedule request");
let err = Err(InferError::Overloaded(TryAcquireError::NoPermits));
if let Err(_) = ctx.streamer.send(err) {
error!("Failed to send back error to the client");
}
}
};
}
}
if backend.num_responses_ready() > 0 {
match backend.pin_mut().pull_tokens() {
Ok(responses) => {
// Iterate through all the decoded tokens
for step in responses.deref() {
if let Some(ctx) = in_flights.get(&step.request_id) {
// Remove from tracked requests
let parcel =
DecodedToken::try_from(step).map(|dt| DecodedTokenContext {
token: dt,
start: ctx.start,
queued: ctx.queued,
channel: ctx.streamer.clone(),
});
// Submit the work to the post_processor
let posted = post_processor_sender.send((step.request_id, parcel));
if posted.is_err() || step.is_final {
debug!("Removing {}", step.request_id);
let _ = in_flights.remove(&step.request_id);
}
} else {
warn!("Untracked request {}", step.request_id,);
}
}
}
Err(ref err) => {
error!("Failed to get responses from the executor: {}.", err.what());
break 'scheduler;
}
}
}
// Hint the CPU we are spin-locking
hint::spin_loop();
}
}
fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
tokenizer: Tokenizer,
max_inflight_requests: usize,
mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
) {
let mut states: HashMap<u64, Vec<u32>> = HashMap::with_capacity(max_inflight_requests * 2);
'post_processor: loop {
if decoded_tokens.is_closed() {
warn!("Post processor IPC is closed, loop will exit now.");
break 'post_processor;
}
if let Some((request_id, decoded)) = decoded_tokens.blocking_recv() {
match decoded {
Ok(ctx) => {
states
.entry(request_id)
.and_modify(|s| s.push(*&ctx.token.id))
.or_insert_with(|| {
let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
state.push(*&ctx.token.id);
state
});
let out = match tokenizer.decode(&[ctx.token.id], false) {
Ok(text) => {
let is_special =
tokenizer.get_added_vocabulary().is_special_token(&text);
let token = Token {
id: ctx.token.id,
text,
logprob: ctx.token.log_prob,
special: is_special,
};
let out = if !ctx.token.is_final {
InferStreamResponse::Intermediate {
token,
top_tokens: vec![],
}
} else {
let tokens = states.remove(&request_id).unwrap();
let text = tokenizer.decode(&tokens, true);
let generated_text = GeneratedText {
text: text.unwrap(),
generated_tokens: tokens.len() as u32,
finish_reason: FinishReason::EndOfSequenceToken,
seed: None,
};
InferStreamResponse::End {
token,
top_tokens: vec![],
generated_text,
start: ctx.start.unwrap(),
queued: ctx.queued,
}
};
Ok(out)
}
Err(err) => Err(GenerationError(err.to_string())),
};
if let Err(_) = ctx.channel.send(out) {
warn!("Failed to send decoded token back to the user")
}
}
Err(_err) => {
todo!("what do we do?")
}
}
}
}
}
fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
engine_folder: P,
executor_worker_path: PP,
) -> Result<(String, String), TensorRtLlmBackendError> {
// Retrieve paths as &str for the backend creation
let engine_folder = engine_folder.as_ref();
let executor_worker_path = executor_worker_path.as_ref();
// Ensure the engine folder exists
if !engine_folder.exists() {
let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
// Ensure executor worker binary exists
if !executor_worker_path.exists() {
let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf());
error!("Path validation failed: {}", err,);
return Err(err);
}
let engine_folder = String::from(
engine_folder
.to_str()
.expect("Failed to convert engine_folder to valid UTF-8"),
);
let executor_worker_path = String::from(
executor_worker_path
.to_str()
.expect("Failed to convert executor_worker_path to valid UTF-8"),
);
Ok((engine_folder, executor_worker_path))
}
unsafe impl Send for TensorRtLlmBackendImpl {}
pub struct TensorRtLlmBackendV2 {
executor_looper: JoinHandle<()>,
post_processor_looper: JoinHandle<()>,
executor: UnboundedSender<GenerationContext>,
}
impl TensorRtLlmBackendV2 {
pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(
tokenizer: Tokenizer,
engine_folder: P,
executor_worker_path: PP,
max_inflight_requests: usize,
) -> Result<Self, TensorRtLlmBackendError> {
let (engine_folder, executor_worker_path) =
ensure_paths_exist(engine_folder, executor_worker_path)?;
// Allocate the IPC layer to communicate with the backend
let (executor_sender, executor_receiver) = unbounded_channel();
let (post_processor_sender, post_processor_receiver) = unbounded_channel();
// Create the FFI backend
let backend = create_tensorrt_llm_backend(&engine_folder, &executor_worker_path)
.map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
// The executor looper is responsible for scheduling requests and pulling their state at regular intervals
let executor_looper = spawn_blocking(move || {
executor_status_looper(
backend,
max_inflight_requests,
executor_receiver,
post_processor_sender,
)
});
// The post-processor looper is responsible for receiving batches of tokens, decoding them and sending them back to the user
let post_processor_looper = spawn_blocking(move || {
post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
});
Ok(TensorRtLlmBackendV2 {
executor_looper,
post_processor_looper,
executor: executor_sender,
})
}
fn validate(request: &ValidGenerateRequest) -> InferResult<()> {
if request.input_ids.is_none() {
return Err(ValidationError(UnsupportedModality("No token provided")));
}
if request.top_n_tokens > 1 {
return Err(ValidationError(TopNTokensDisabled));
}
// TODO: Is it really needed? How can it be validated before?
if request.parameters.grammar.is_some() {
return Err(ValidationError(Grammar));
}
match request.inputs.len() {
0 => Err(ValidationError(EmptyInput)),
2.. => Err(GenerationError(
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(_) => Ok(()),
Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
},
}
}
}
#[async_trait]
impl Backend for TensorRtLlmBackendV2 {
fn schedule(
&self,
inner: ValidGenerateRequest,
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
Self::validate(&inner)?;
// Open-up the stream to send tokens
let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();
// Send the context to the executor for scheduling
let queued = Instant::now();
match self.executor.send(GenerationContext {
request: inner,
start: None,
queued,
streamer,
}) {
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
Err(_) => Err(GenerationError(
"Failed to submit request to the backend".into(),
)),
}
}
async fn health(&self, _: bool) -> bool {
!self.executor_looper.is_finished() && !self.post_processor_looper.is_finished()
}
}

View File

@ -1,10 +1,16 @@
use std::path::{Path, PathBuf};
use clap::Parser;
use std::collections::HashMap;
use std::path::PathBuf;
use hf_hub::api::tokio::{Api, ApiBuilder};
use hf_hub::{Cache, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::info;
use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackend;
use text_generation_router::server;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
use text_generation_router::server::get_base_tokenizer;
use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, HubTokenizerConfig};
/// App Configuration
#[derive(Parser, Debug)]
@ -48,14 +54,138 @@ struct Args {
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env)]
auth_token: Option<String>,
#[clap(long, env, help = "Path to the TensorRT-LLM Orchestrator worker")]
executor_worker: PathBuf,
#[clap(default_value = "on", long, env)]
usage_stats: usage_stats::UsageStatsLevel,
}
async fn get_tokenizer(
tokenizer_name: &str,
tokenizer_config_path: Option<&str>,
revision: Option<&str>,
) -> Option<Tokenizer> {
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
let local_path = Path::new(tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
Type::None
};
// Load tokenizer and model info
let (
tokenizer_filename,
_config_filename,
tokenizer_config_filename,
_preprocessor_config_filename,
_processor_config_filename,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.unwrap_or_else(|| "main").to_string(),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main").to_string(),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
)
}
};
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
}
#[tokio::main]
@ -83,10 +213,10 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
messages_api_enabled,
max_client_batch_size,
auth_token,
executor_worker,
usage_stats,
} = args;
// Launch Tokio runtime
@ -124,18 +254,26 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
)));
}
// Run server
let tokenizer = Tokenizer::from_pretrained(
tokenizer_name.clone(),
Some(FromPretrainedParameters {
revision: revision.clone().unwrap_or(String::from("main")),
user_agent: HashMap::new(),
auth_token,
}),
// Create the backend
let tokenizer = get_tokenizer(
&tokenizer_name,
tokenizer_config_path.as_deref(),
revision.as_deref(),
)
.map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
.await
.expect("Failed to retrieve tokenizer implementation");
let backend = TensorRtLlmBackend::new(tokenizer, model_id, executor_worker)?;
info!("Successfully retrieved tokenizer {}", &tokenizer_name);
let backend = TensorRtLlmBackendV2::new(
tokenizer,
model_id,
executor_worker,
max_concurrent_requests,
)?;
info!("Successfully created backend");
// Run server
server::run(
backend,
max_concurrent_requests,
@ -145,7 +283,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
max_input_tokens,
max_total_tokens,
validation_workers,
None,
auth_token,
tokenizer_name,
tokenizer_config_path,
revision,
@ -155,11 +293,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
false,
None,
None,
messages_api_enabled,
true,
max_client_batch_size,
false,
false,
usage_stats,
)
.await?;
Ok(())

View File

@ -0,0 +1,22 @@
///
/// Extract the first line of the provided string reference.
/// If there are no lines in the buffer, it returns a string
/// whose content is defined by `fail`.
/// # Arguments
///
/// * `s`: The string buffer to extract the first line from
/// * `fail`: The content returned if no lines are present in `s`
///
/// returns: String
///
/// # Examples
///
/// ```
/// let s = "My name is Morgan.\n I'm working at Hugging Face.";
/// first_line(s, "No line in string");
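/// // With no lines in the buffer, the fallback content is returned:
/// first_line("", "No line in string");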
/// ```
#[inline]
pub(crate) fn first_line(s: &str, fail: &str) -> String {
s.lines().next().unwrap_or(fail).to_string()
}

View File

@ -6,7 +6,7 @@ use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
@ -36,18 +36,14 @@ impl BackendV2 {
speculate: u32,
) -> Self {
// Infer shared state
let attention = if let Ok(attention) = std::env::var("ATTENTION") {
attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"))
} else {
Attention::Paged
};
let block_size = if attention == Attention::FlashDecoding {
256
} else {
16
let attention = std::env::var("ATTENTION").unwrap_or("paged".to_string());
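// Block size used by each attention implementation: flashinfer works with single-token
// blocks, flashdecoding with 256-token blocks, and paged attention with 16-token blocks.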
let block_size = match attention.as_str() {
"flashinfer" => 1,
"flashdecoding" => 256,
"paged" => 16,
_ => unreachable!(),
};
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new());

View File

@ -44,6 +44,8 @@ struct Args {
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env, value_enum)]
trust_remote_code: bool,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
@ -63,8 +65,6 @@ struct Args {
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
validation_workers,
api_key,
json_output,
@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
@ -184,13 +184,13 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
hostname,
port,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,

View File

@ -1,12 +1,14 @@
use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};
/// Batching and inference logic
use crate::client::{
Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,
};
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
use text_generation_router::{Attention, FinishReason, PrefillToken, Token};
use text_generation_router::{FinishReason, PrefillToken, Token};
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
@ -31,27 +33,22 @@ impl BackendV3 {
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
shard_info: InfoResponse,
) -> Self {
let prefix_caching =
std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
let attention: String = std::env::var("ATTENTION").expect("attention env var");
if shard_info.support_chunking {
tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
}
let attention: Attention = attention
.parse()
.unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`"));
let block_size = attention.block_size();
let block_size = shard_info.block_size;
let queue = Queue::new(
requires_padding,
shard_info.requires_padding,
block_size,
prefix_caching,
window_size,
speculate,
shard_info.use_prefix_caching,
shard_info.window_size,
shard_info.speculate,
max_batch_total_tokens,
shard_info.support_chunking,
);
let batching_task_notifier = Arc::new(Notify::new());
@ -63,6 +60,7 @@ impl BackendV3 {
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.support_chunking,
queue.clone(),
batching_task_notifier.clone(),
));
@ -127,6 +125,7 @@ pub(crate) async fn batching_task(
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
support_chunking: bool,
queue: Queue,
notifier: Arc<Notify>,
) {
@ -147,7 +146,7 @@ pub(crate) async fn batching_task(
)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries)
let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
.instrument(span)
.await;
let mut waiting_tokens = 1;
@ -158,10 +157,24 @@ pub(crate) async fn batching_task(
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let current_tokens = batch.current_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let (min_size, max_size, prefill_token_budget) = if support_chunking {
// Since the next batch will be concatenated with the current batch,
// the current batch tokens must be subtracted from the prefill budget
let prefill_token_budget =
max_batch_prefill_tokens.saturating_sub(current_tokens);
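// For example (hypothetical numbers): with max_batch_prefill_tokens = 4096 and a current
// batch already holding 1024 tokens, the incoming batch may prefill at most 3072 tokens.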
// We can ignore min_size and max_size
// Models that rely on max_size cannot support chunking
// Regarding min_size, chunking allows us to consistently run at the compute
// bound, making min_size useless.
(None, None, prefill_token_budget)
} else {
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
@ -173,24 +186,34 @@ pub(crate) async fn batching_task(
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size =
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
(min_size, max_size, max_batch_prefill_tokens)
};
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
if let Some((new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
let counter = if support_chunking {
metrics::counter!("tgi_batch_concat", "reason" => "chunking")
} else {
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
};
counter.increment(1);
}
let cached_batch = if support_chunking {
// Concat current batch to the new one
batches.pop()
} else {
// Requests are waiting only if we don't support chunking
entries.iter_mut().for_each(|(_, entry)| {
// Create a new span to add the info that this entry is waiting
// because a new batch is being computed
@ -201,17 +224,23 @@ pub(crate) async fn batching_task(
// Update entry
entry.temp_span = Some(entry_waiting_span);
});
None
};
entries.extend(new_entries);
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
let new_cached_batch =
prefill(&mut client, new_batch, cached_batch, &mut entries)
.instrument(span)
.await;
// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
entries.extend(new_entries);
batches.push(new_cached_batch);
} else if support_chunking {
// New cached batch is empty, no work left
break;
}
}
@ -244,13 +273,14 @@ pub(crate) async fn batching_task(
async fn prefill(
client: &mut ShardedClient,
batch: Batch,
cached_batch: Option<CachedBatch>,
entries: &mut IntMap<u64, Entry>,
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
match client.prefill(batch, cached_batch).await {
Ok((generations, next_batch, timings)) => {
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
@ -259,6 +289,10 @@ async fn prefill(
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", "method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")

View File

@ -158,7 +158,8 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
prefix_len: 0,
cache_len: 0,
chunk_len: None,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
@ -217,13 +218,23 @@ impl Client {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
PrefillTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
@ -252,14 +263,16 @@ impl Client {
}
pub struct PrefillTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(Duration::from_nanos),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),

View File

@ -29,15 +29,6 @@ pub trait Health {
async fn model_health(&self) -> Result<()>;
}
#[derive(Debug)]
pub struct ShardInfo {
pub requires_padding: bool,
pub dtype: String,
pub device_type: String,
pub window_size: Option<u32>,
pub speculate: u32,
}
#[derive(Error, Debug, Clone)]
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")]

View File

@ -1,6 +1,6 @@
use crate::client::{ClientError, Result};
use crate::client::Health;
/// Multi shard Client
use crate::client::{Health, ShardInfo};
use crate::client::{ClientError, Result};
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::{
@ -49,13 +49,13 @@ impl ShardedClient {
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> {
pub async fn info(&mut self) -> Result<InfoResponse> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
join_all(futures).await.pop().unwrap()
}
/// GRPC health check
@ -135,11 +135,12 @@ impl ShardedClient {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
#[allow(clippy::type_complexity)]
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
@ -194,18 +195,6 @@ impl ShardedClient {
}
}
impl From<InfoResponse> for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> {
@ -246,8 +235,9 @@ impl Health for ShardedClient {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
prefix_len: 0,
cache_len: 0,
adapter_id: None,
chunk_len: None,
};
let batch = Batch {
id: u64::MAX,
@ -256,7 +246,7 @@ impl Health for ShardedClient {
max_tokens: 2,
max_blocks: 1,
};
self.clone().prefill(batch).await?;
self.clone().prefill(batch, None).await?;
Ok(())
}
}

View File

@ -29,6 +29,14 @@ pub struct BackendInfo {
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
#[schema(example = "false")]
pub support_chunking: bool,
#[schema(example = "false")]
pub prefix_caching: bool,
#[schema(example = "flashinfer")]
pub attention_impl: String,
#[schema(example = "1")]
pub block_size: u32,
}
#[allow(clippy::too_many_arguments)]
@ -110,6 +118,10 @@ pub async fn connect_backend(
model_device_type: shard_info.device_type.clone(),
model_dtype: shard_info.dtype.clone(),
speculate: shard_info.speculate as usize,
support_chunking: shard_info.support_chunking,
prefix_caching: shard_info.use_prefix_caching,
attention_impl: shard_info.attention_impl.clone(),
block_size: shard_info.block_size,
};
let backend = BackendV3::new(
@ -119,9 +131,7 @@ pub async fn connect_backend(
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
shard_info,
);
tracing::info!("Using backend V3");

View File

@ -44,6 +44,8 @@ struct Args {
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(long, env, value_enum)]
trust_remote_code: bool,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
@ -63,8 +65,6 @@ struct Args {
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
@ -101,6 +101,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
validation_workers,
api_key,
json_output,
@ -110,7 +111,6 @@ async fn main() -> Result<(), RouterError> {
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,
@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> {
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
if let Some(max_batch_size) = max_batch_size {
if max_batch_size == 0 {
return Err(RouterError::ArgumentValidation(
@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> {
}
}
let (backend, _backend_info) = connect_backend(
let (backend, backend_info) = connect_backend(
max_input_tokens,
max_total_tokens,
master_shard_uds_path,
@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> {
)
.await?;
// Validate remaining args now that the backend is known
let support_chunking = backend_info.support_chunking;
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
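// When the backend supports prefill chunking, a prompt longer than the prefill budget can
// be split across several prefill calls, so the max_input_tokens check below only applies
// to backends without chunking support.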
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if max_batch_prefill_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
// Run server
server::run(
backend,
@ -184,13 +184,13 @@ async fn main() -> Result<(), RouterError> {
tokenizer_name,
tokenizer_config_path,
revision,
trust_remote_code,
hostname,
port,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats,

View File

@ -4,7 +4,7 @@ use crate::client::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::cmp::max;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
@ -50,6 +50,7 @@ impl Queue {
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@ -62,6 +63,7 @@ impl Queue {
window_size,
speculate,
max_batch_total_tokens,
support_chunking,
queue_receiver,
));
@ -87,6 +89,10 @@ impl Queue {
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
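// Nothing can be scheduled without any budget (e.g. when prefill chunking has already
// consumed the whole prefill budget), so skip the round trip to the queue task.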
if prefill_token_budget == 0 || token_budget == 0 {
return None;
};
// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
// Send next batch command to the background task managing the state
@ -108,6 +114,7 @@ impl Queue {
}
// Background task responsible for the queue state
#[allow(clippy::too_many_arguments)]
async fn queue_task(
requires_padding: bool,
block_size: u32,
@ -115,6 +122,7 @@ async fn queue_task(
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(
@ -124,6 +132,7 @@ async fn queue_task(
window_size,
speculate,
max_batch_total_tokens,
support_chunking,
);
while let Some(cmd) = receiver.recv().await {
@ -166,12 +175,14 @@ struct State {
/// Paged Attention block size
block_size: u32,
/// Sliding window
window_size: Option<u32>,
/// Speculation amount
speculate: u32,
/// Whether the model allows prefill chunking
/// If it does, the last request in the batch will be split to exactly match the prefill
/// token budget
support_chunking: bool,
/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,
}
@ -184,6 +195,7 @@ impl State {
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
) -> Self {
let block_allocator = (!requires_padding).then(|| {
BlockAllocator::new(
@ -199,8 +211,8 @@ impl State {
next_id: 0,
next_batch_id: 0,
block_size,
window_size,
speculate,
support_chunking,
block_allocator,
}
}
@ -287,32 +299,7 @@ impl State {
}
None
}
Some(_block_allocator) => {
prefill_tokens += entry.request.input_length;
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
Some(window_size) => min(
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
};
decode_tokens += max_new_tokens;
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry));
break;
}
let tokens = entry.request.input_length
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
Some(block_allocator) => {
// If the user wants the prefill logprobs, we cannot reuse the cache.
// So no input_ids for the radix tree.
let input_ids = if entry.request.decoder_input_details {
@ -321,10 +308,73 @@ impl State {
entry.request.input_ids.clone()
};
Some((tokens, input_ids))
let tokens = entry.request.input_length
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
tracing::debug!("Allocating {tokens} with {input_ids:?}");
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break 'entry_loop;
}
Some(mut block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
if block_allocation.prefix_len == entry.request.input_length {
// The whole request was found in the radix trie
// However, for the transformer forward to work, we need to
// have at least one token of postfix.
block_allocation.prefix_len -= 1;
}
block_allocation
}
};
batch.push((id, entry, block_allocation));
let postfix_len = entry.request.input_length - block_allocation.prefix_len;
if prefill_tokens + postfix_len > prefill_token_budget {
// Entry is over budget
if self.support_chunking {
// We support chunking: cap this entry's chunk so the batch exactly fills prefill_token_budget
let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens);
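// Illustrative example (hypothetical numbers): with prefill_token_budget = 512 and
// prefill_tokens = 448 already planned, chunk_len = 64, so only the first 64 tokens of
// this entry are prefilled now and the remainder is handled in a later chunk.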
if chunk_len > 0 {
// Push this entry inside the batch
batch.push((id, entry, Some(block_allocation), Some(chunk_len)));
} else {
// We cannot prefill even one token for this entry
// Add it back to the queue
self.entries.push_front((id, entry));
}
tracing::debug!(
"Matched budget: prefill_tokens={} == {prefill_token_budget}",
prefill_tokens + postfix_len
);
break 'entry_loop;
} else {
// We don't support chunking, this entry needs to go back to the buffer
// Add it back to the front
tracing::debug!(
"Over budget: prefill_tokens={} > {prefill_token_budget}",
prefill_tokens + postfix_len
);
self.entries.push_front((id, entry));
break 'entry_loop;
}
}
prefill_tokens += postfix_len;
Some(block_allocation)
}
};
batch.push((id, entry, block_allocation, None));
if Some(batch.len()) == max_size {
break;
}
@ -342,7 +392,7 @@ impl State {
// Batch is too small
if batch.len() < min_size {
// Add back entries to the queue in the correct order
for (id, entry, _) in batch.into_iter().rev() {
for (id, entry, _, _) in batch.into_iter().rev() {
self.entries.push_front((id, entry));
}
return None;
@ -353,29 +403,7 @@ impl State {
let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
for (id, mut entry, block_allocation) in batch {
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
(block_allocation, &self.block_allocator)
{
tracing::debug!("Allocating {tokens} with {input_ids:?}");
match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
continue;
}
Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
Some(block_allocation)
}
}
} else {
None
};
tracing::debug!("Accepting entry");
for (id, mut entry, block_allocation, chunk_len) in batch {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
@ -427,8 +455,9 @@ impl State {
top_n_tokens: entry.request.top_n_tokens,
blocks,
slots,
prefix_len,
cache_len: prefix_len,
adapter_id: entry.request.adapter_id.clone(),
chunk_len,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -436,12 +465,6 @@ impl State {
batch_entries.insert(id, entry);
}
// Empty batch
if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// Final batch size
let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size);
@ -531,7 +554,7 @@ mod tests {
request: ValidGenerateRequest {
inputs: vec![],
input_ids: Some(Arc::new(vec![])),
input_length: 0,
input_length: 1,
add_special_tokens: true,
truncate: 0,
decoder_input_details: false,
@ -567,7 +590,7 @@ mod tests {
#[tokio::test]
async fn test_append() {
let mut state = State::new(false, 1, false, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
@ -583,7 +606,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_empty() {
let mut state = State::new(false, 1, false, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
assert!(state.next_batch(None, None, 1, 1).await.is_none());
assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
@ -591,7 +614,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_min_size() {
let mut state = State::new(false, 1, false, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@ -623,7 +646,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_max_size() {
let mut state = State::new(false, 1, false, None, 0, 16);
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@ -643,7 +666,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, false, None, 0, 2);
let mut state = State::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@ -676,14 +699,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@ -691,7 +714,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -724,7 +747,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_max_size() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -740,7 +763,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -765,7 +788,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, false, None, 2, 16);
let queue = Queue::new(true, 1, false, None, 2, 16, false);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@ -784,7 +807,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
let queue = Queue::new(false, 1, false, None, 0, 16);
let queue = Queue::new(false, 1, false, None, 0, 16, false);
let (entry, _) = default_entry();
queue.append(entry);

View File

@ -158,7 +158,8 @@ async fn prefill(
top_n_tokens: top_n_tokens.unwrap_or(0),
blocks: vec![],
slots: vec![],
prefix_len: 0,
cache_len: 0,
chunk_len: None,
adapter_id: None,
})
.collect();
@ -173,7 +174,7 @@ async fn prefill(
// Run prefill
let start_time = Instant::now();
let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
let (_, decode_batch, _) = client.prefill(batch.clone(), None).await?;
// Get latency
let latency = start_time.elapsed();

View File

@ -178,6 +178,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.clear_cache(None)
.await
.expect("Unable to clear cache");
tracing::info!("Connected");
// Run app

View File

@ -316,6 +316,98 @@
}
}
},
"/invocations": {
"post": {
"tags": [
"Text Generation Inference"
],
"summary": "Generate tokens from Sagemaker request",
"operationId": "sagemaker_compatibility",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SagemakerRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Generated Chat Completion",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/SagemakerResponse"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/SagemakerStreamResponse"
}
}
}
},
"422": {
"description": "Input validation error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Input validation error",
"error_type": "validation"
}
}
}
},
"424": {
"description": "Generation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Request failed during generation",
"error_type": "generation"
}
}
}
},
"429": {
"description": "Model is overloaded",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Model is overloaded",
"error_type": "overloaded"
}
}
}
},
"500": {
"description": "Incomplete generation",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ErrorResponse"
},
"example": {
"error": "Incomplete generation",
"error_type": "incomplete_generation"
}
}
}
}
}
}
},
"/metrics": {
"get": {
"tags": [
@ -1865,6 +1957,45 @@
"type": "string"
}
},
"SagemakerRequest": {
"oneOf": [
{
"$ref": "#/components/schemas/CompatGenerateRequest"
},
{
"$ref": "#/components/schemas/ChatRequest"
},
{
"$ref": "#/components/schemas/CompletionRequest"
}
]
},
"SagemakerResponse": {
"oneOf": [
{
"$ref": "#/components/schemas/GenerateResponse"
},
{
"$ref": "#/components/schemas/ChatCompletion"
},
{
"$ref": "#/components/schemas/CompletionFinal"
}
]
},
"SagemakerStreamResponse": {
"oneOf": [
{
"$ref": "#/components/schemas/StreamResponse"
},
{
"$ref": "#/components/schemas/ChatCompletionChunk"
},
{
"$ref": "#/components/schemas/Chunk"
}
]
},
"SimpleToken": {
"type": "object",
"required": [

View File

@ -141,9 +141,7 @@ TGI can be deployed on various cloud providers for scalable and robust text gene
## Amazon SageMaker
To enable the Messages API in Amazon SageMaker you need to set the environment variable `MESSAGES_API_ENABLED=true`.
This will modify the `/invocations` route to accept Messages dictionaries consisting of role and content. See the example below on how to deploy Llama with the new Messages API.
Amazon SageMaker natively supports the Messages API:
```python
import json
@ -161,12 +159,11 @@ except ValueError:
hub = {
'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',
'SM_NUM_GPUS': json.dumps(1),
'MESSAGES_API_ENABLED': True
}
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri("huggingface",version="1.4.0"),
image_uri=get_huggingface_llm_image_uri("huggingface",version="2.3.2"),
env=hub,
role=role,
)

View File

@ -93,10 +93,10 @@ Options:
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA
[env: KV_CACHE_DTYPE=]
[possible values: fp8_e5m2]
[possible values: fp8_e4m3fn, fp8_e5m2]
```
## TRUST_REMOTE_CODE

View File

@ -8,6 +8,7 @@ Text Generation Inference enables serving optimized models. The following sectio
- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)

View File

@ -26,7 +26,6 @@ As of release 2.1.2 this is an example of the data collected:
"max_top_n_tokens": 5,
"max_total_tokens": 2048,
"max_waiting_tokens": 20,
"messages_api_enabled": false,
"model_config": {
"model_type": "Bloom"
},

View File

@ -978,15 +978,16 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1728381423,
"narHash": "sha256-gpHy1WtlA8ZTd8XmxsdCoDd4Z7DE7co37lH7P+nsADA=",
"lastModified": 1729531056,
"narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "93123736c97e9f7bfe825bfaf3d7de0fc9a21a1e",
"rev": "a84a90281a17b15762873845c947e5c78f5a8dd1",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "marlin-kernels-0.3.0",
"repo": "text-generation-inference-nix",
"type": "github"
}

View File

@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
@ -137,6 +137,11 @@
impure = callPackage ./nix/impure-shell.nix { inherit server; };
impureWithCuda = callPackage ./nix/impure-shell.nix {
inherit server;
withCuda = true;
};
impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {
server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };
};

View File

@ -9,13 +9,16 @@ import subprocess
import sys
import tempfile
import time
from typing import Dict, List, Optional
import docker
import pytest
import base64
from pathlib import Path
from typing import Dict, List, Optional
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
from syrupy.extensions.json import JSONSnapshotExtension
from text_generation import AsyncClient
from text_generation.types import (
BestOfSequence,
@ -403,6 +406,7 @@ def launcher(event_loop):
print(" ".join(args), file=sys.stderr)
env["LOG_LEVEL"] = "info,text_generation_router=debug"
env["PREFILL_CHUNKING"] = "1"
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
@ -501,6 +505,7 @@ def launcher(event_loop):
env = {
"LOG_LEVEL": "info,text_generation_router=debug",
"PREFILL_CHUNKING": "1",
}
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
@ -642,3 +647,22 @@ def generate_multi():
return responses
return generate_load_inner
# TODO fix the server parser to count inline image tokens correctly
@pytest.fixture
def chicken():
path = Path(__file__).parent / "images" / "chicken_on_money.png"
with open(path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture
def cow_beach():
path = Path(__file__).parent / "images" / "cow_beach.png"
with open(path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"

View File

@ -11,27 +11,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.1875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.93359375,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1796875,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -39,66 +39,66 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.109375,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.079956055,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.2763672,
"logprob": -0.028808594,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37548828,
"logprob": -0.013671875,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4628906,
"logprob": -0.69921875,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02885437,
"logprob": -0.0005874634,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.2565918,
"logprob": -0.026855469,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0063438416,
"logprob": -0.00020885468,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3056641,
"logprob": -0.17773438,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.6035156,
"id": 18065,
"logprob": -0.703125,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
}

View File

@ -1,8 +1,8 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 3,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
@ -11,22 +11,22 @@
},
{
"id": 374,
"logprob": -22.96875,
"logprob": -18.0,
"text": " is"
},
{
"id": 5655,
"logprob": -10.71875,
"logprob": -11.75,
"text": " deep"
},
{
"id": 6975,
"logprob": -2.6992188,
"logprob": -2.0625,
"text": " learning"
},
{
"id": 30,
"logprob": -4.8398438,
"logprob": -6.0,
"text": "?"
}
],
@ -34,24 +34,66 @@
"tokens": [
{
"id": 720,
"logprob": -0.4411621,
"logprob": 0.0,
"special": false,
"text": " \n"
},
{
"id": 220,
"logprob": -0.35864258,
"id": 34564,
"logprob": -0.11279297,
"special": false,
"text": " "
"text": "Deep"
},
{
"id": 128001,
"id": 6975,
"logprob": -0.16015625,
"special": false,
"text": " learning"
},
{
"id": 320,
"logprob": -0.25195312,
"special": false,
"text": " ("
},
{
"id": 16931,
"logprob": -1.703125,
"special": false,
"text": "DL"
},
{
"id": 8,
"logprob": 0.0,
"special": true,
"text": "<|end_of_text|>"
"special": false,
"text": ")"
},
{
"id": 374,
"logprob": -1.140625,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 1207,
"logprob": -1.3125,
"special": false,
"text": " sub"
},
{
"id": 2630,
"logprob": 0.0,
"special": false,
"text": "field"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning? \n "
"generated_text": "What is deep learning? \nDeep learning (DL) is a subfield"
}

View File

@ -12,27 +12,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.1875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.93359375,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1796875,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -40,68 +40,68 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.109375,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.0047912598,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.025512695,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.012145996,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.72265625,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0005760193,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02722168,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00023651123,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.17285156,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.703125,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
},
{
"details": {
@ -116,27 +116,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.21875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.95703125,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.9375,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1328125,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -144,68 +144,68 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.1796875,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.02758789,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.013366699,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.6953125,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0004863739,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02709961,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00022506714,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.19726562,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.77734375,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
},
{
"details": {
@ -220,27 +220,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.21875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.95703125,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.9375,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1328125,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -248,68 +248,68 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.1796875,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.02758789,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.013366699,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.6953125,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0004863739,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02709961,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00022506714,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.19726562,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.77734375,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
},
{
"details": {
@ -324,27 +324,27 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.21875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.95703125,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.9375,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1328125,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
@ -352,67 +352,67 @@
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.1796875,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.07897949,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.27734375,
"logprob": -0.02758789,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37402344,
"logprob": -0.013366699,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4511719,
"logprob": -0.6953125,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02909851,
"logprob": -0.0004863739,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.25854492,
"logprob": -0.02709961,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0061798096,
"logprob": -0.00022506714,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3046875,
"logprob": -0.19726562,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.5537109,
"id": 18065,
"logprob": -0.77734375,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
}
]

View File

@ -10,80 +10,95 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 1824,
"logprob": -9.2890625,
"text": "What"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 349,
"logprob": -1.1503906,
"text": "is"
},
{
"id": 3534,
"logprob": -9.5859375,
"text": "deep"
},
{
"id": 5168,
"logprob": -1.3945312,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.4555664,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"logprob": -0.6953125,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"logprob": -0.4777832,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"id": 23229,
"logprob": -0.13256836,
"special": false,
"text": "#"
"text": "Deep"
},
{
"id": 3735,
"logprob": -2.8808594,
"id": 5168,
"logprob": -0.023849487,
"special": false,
"text": " Test"
"text": " learning"
},
{
"id": 2159,
"logprob": -0.37280273,
"id": 349,
"logprob": -0.13977051,
"special": false,
"text": " request"
"text": " is"
},
{
"id": 13,
"logprob": -0.26098633,
"id": 264,
"logprob": -0.14489746,
"special": false,
"text": "\n"
"text": " a"
},
{
"id": 13,
"logprob": -0.0017137527,
"id": 19804,
"logprob": -0.63183594,
"special": false,
"text": "\n"
"text": " subset"
},
{
"id": 1064,
"logprob": -2.2695312,
"id": 302,
"logprob": -0.010314941,
"special": false,
"text": "##"
"text": " of"
},
{
"id": 3735,
"logprob": -1.9238281,
"id": 5599,
"logprob": -0.0635376,
"special": false,
"text": " Test"
"text": " machine"
},
{
"id": 2159,
"logprob": -0.48828125,
"id": 5168,
"logprob": -0.0028572083,
"special": false,
"text": " request"
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
"generated_text": "\n\nDeep learning is a subset of machine learning"
}

View File

@ -10,42 +10,28 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 349,
"logprob": -12.0546875,
"text": "is"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 3534,
"logprob": -10.53125,
"text": "deep"
},
{
"id": 5168,
"logprob": -2.71875,
"text": "learning"
},
{
"id": 28804,
"logprob": -5.0078125,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 13,
"logprob": -0.34838867,
"special": false,
"text": "\n"
},
{
"id": 13940,
"logprob": -0.38916016,
"special": false,
"text": "``"
},
{
"id": 28832,
"logprob": 0.0,
"special": false,
"text": "`"
},
{
"id": 3371,
"logprob": -1.2529297,
"special": false,
"text": "json"
},
{
"id": 13,
"logprob": 0.0,
@ -53,37 +39,61 @@
"text": "\n"
},
{
"id": 28751,
"logprob": 0.0,
"id": 23229,
"logprob": -0.18237305,
"special": false,
"text": "{"
"text": "Deep"
},
{
"id": 13,
"id": 17504,
"logprob": 0.0,
"special": false,
"text": "\n"
"text": " Learning"
},
{
"id": 2287,
"id": 349,
"logprob": 0.0,
"special": false,
"text": " "
"text": " is"
},
{
"id": 345,
"id": 264,
"logprob": 0.0,
"special": false,
"text": " \""
"text": " a"
},
{
"id": 3134,
"logprob": -0.640625,
"id": 19804,
"logprob": 0.0,
"special": false,
"text": "request"
"text": " subset"
},
{
"id": 302,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 13253,
"logprob": -0.6040039,
"special": false,
"text": " Machine"
},
{
"id": 17504,
"logprob": 0.0,
"special": false,
"text": " Learning"
},
{
"id": 28725,
"logprob": -0.11621094,
"special": false,
"text": ","
}
],
"top_tokens": null
},
"generated_text": "Test request\n```json\n{\n \"request"
"generated_text": "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
}

View File

@ -11,82 +11,97 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 1824,
"logprob": -9.2890625,
"text": "What"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 349,
"logprob": -1.1503906,
"text": "is"
},
{
"id": 3534,
"logprob": -9.5859375,
"text": "deep"
},
{
"id": 5168,
"logprob": -1.3945312,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.4555664,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"logprob": -0.6953125,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"logprob": -0.4777832,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"id": 23229,
"logprob": -0.13232422,
"special": false,
"text": "#"
"text": "Deep"
},
{
"id": 3735,
"logprob": -2.8828125,
"id": 5168,
"logprob": -0.023834229,
"special": false,
"text": " Test"
"text": " learning"
},
{
"id": 2159,
"logprob": -0.37329102,
"id": 349,
"logprob": -0.13977051,
"special": false,
"text": " request"
"text": " is"
},
{
"id": 13,
"logprob": -0.2602539,
"id": 264,
"logprob": -0.14416504,
"special": false,
"text": "\n"
"text": " a"
},
{
"id": 13,
"logprob": -0.0017185211,
"id": 19804,
"logprob": -0.63183594,
"special": false,
"text": "\n"
"text": " subset"
},
{
"id": 1064,
"logprob": -2.2753906,
"id": 302,
"logprob": -0.010223389,
"special": false,
"text": "##"
"text": " of"
},
{
"id": 3735,
"logprob": -1.9316406,
"id": 5599,
"logprob": -0.064208984,
"special": false,
"text": " Test"
"text": " machine"
},
{
"id": 2159,
"logprob": -0.48217773,
"id": 5168,
"logprob": -0.0028266907,
"special": false,
"text": " request"
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
@ -100,82 +115,97 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 1824,
"logprob": -9.2890625,
"text": "What"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 349,
"logprob": -1.1425781,
"text": "is"
},
{
"id": 3534,
"logprob": -9.59375,
"text": "deep"
},
{
"id": 5168,
"logprob": -1.390625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.45532227,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"logprob": -0.6953125,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"logprob": -0.48339844,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"id": 23229,
"logprob": -0.13256836,
"special": false,
"text": "#"
"text": "Deep"
},
{
"id": 3735,
"logprob": -2.8828125,
"id": 5168,
"logprob": -0.02420044,
"special": false,
"text": " Test"
"text": " learning"
},
{
"id": 2159,
"logprob": -0.37329102,
"id": 349,
"logprob": -0.13977051,
"special": false,
"text": " request"
"text": " is"
},
{
"id": 13,
"logprob": -0.2602539,
"id": 264,
"logprob": -0.14501953,
"special": false,
"text": "\n"
"text": " a"
},
{
"id": 13,
"logprob": -0.0017185211,
"id": 19804,
"logprob": -0.63134766,
"special": false,
"text": "\n"
"text": " subset"
},
{
"id": 1064,
"logprob": -2.2753906,
"id": 302,
"logprob": -0.010223389,
"special": false,
"text": "##"
"text": " of"
},
{
"id": 3735,
"logprob": -1.9316406,
"id": 5599,
"logprob": -0.06427002,
"special": false,
"text": " Test"
"text": " machine"
},
{
"id": 2159,
"logprob": -0.48217773,
"id": 5168,
"logprob": -0.002817154,
"special": false,
"text": " request"
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
@ -189,82 +219,97 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 1824,
"logprob": -9.2890625,
"text": "What"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 349,
"logprob": -1.1425781,
"text": "is"
},
{
"id": 3534,
"logprob": -9.59375,
"text": "deep"
},
{
"id": 5168,
"logprob": -1.390625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.45532227,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"logprob": -0.6953125,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"logprob": -0.48339844,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"id": 23229,
"logprob": -0.13256836,
"special": false,
"text": "#"
"text": "Deep"
},
{
"id": 3735,
"logprob": -2.8828125,
"id": 5168,
"logprob": -0.02420044,
"special": false,
"text": " Test"
"text": " learning"
},
{
"id": 2159,
"logprob": -0.37329102,
"id": 349,
"logprob": -0.13977051,
"special": false,
"text": " request"
"text": " is"
},
{
"id": 13,
"logprob": -0.2602539,
"id": 264,
"logprob": -0.14501953,
"special": false,
"text": "\n"
"text": " a"
},
{
"id": 13,
"logprob": -0.0017185211,
"id": 19804,
"logprob": -0.63134766,
"special": false,
"text": "\n"
"text": " subset"
},
{
"id": 1064,
"logprob": -2.2753906,
"id": 302,
"logprob": -0.010223389,
"special": false,
"text": "##"
"text": " of"
},
{
"id": 3735,
"logprob": -1.9316406,
"id": 5599,
"logprob": -0.06427002,
"special": false,
"text": " Test"
"text": " machine"
},
{
"id": 2159,
"logprob": -0.48217773,
"id": 5168,
"logprob": -0.002817154,
"special": false,
"text": " request"
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
"generated_text": "\n\nDeep learning is a subset of machine learning"
},
{
"details": {
@ -278,81 +323,96 @@
"text": "<s>"
},
{
"id": 3735,
"logprob": -11.0078125,
"text": "Test"
"id": 1824,
"logprob": -9.2890625,
"text": "What"
},
{
"id": 2159,
"logprob": -13.59375,
"text": "request"
"id": 349,
"logprob": -1.1425781,
"text": "is"
},
{
"id": 3534,
"logprob": -9.59375,
"text": "deep"
},
{
"id": 5168,
"logprob": -1.390625,
"text": "learning"
},
{
"id": 28804,
"logprob": -0.45532227,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.7089844,
"logprob": -0.6953125,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.68847656,
"logprob": -0.48339844,
"special": false,
"text": "\n"
},
{
"id": 28771,
"logprob": -1.9394531,
"id": 23229,
"logprob": -0.13256836,
"special": false,
"text": "#"
"text": "Deep"
},
{
"id": 3735,
"logprob": -2.8828125,
"id": 5168,
"logprob": -0.02420044,
"special": false,
"text": " Test"
"text": " learning"
},
{
"id": 2159,
"logprob": -0.37329102,
"id": 349,
"logprob": -0.13977051,
"special": false,
"text": " request"
"text": " is"
},
{
"id": 13,
"logprob": -0.2602539,
"id": 264,
"logprob": -0.14501953,
"special": false,
"text": "\n"
"text": " a"
},
{
"id": 13,
"logprob": -0.0017185211,
"id": 19804,
"logprob": -0.63134766,
"special": false,
"text": "\n"
"text": " subset"
},
{
"id": 1064,
"logprob": -2.2753906,
"id": 302,
"logprob": -0.010223389,
"special": false,
"text": "##"
"text": " of"
},
{
"id": 3735,
"logprob": -1.9316406,
"id": 5599,
"logprob": -0.06427002,
"special": false,
"text": " Test"
"text": " machine"
},
{
"id": 2159,
"logprob": -0.48217773,
"id": 5168,
"logprob": -0.002817154,
"special": false,
"text": " request"
"text": " learning"
}
],
"top_tokens": null
},
"generated_text": "\n\n# Test request\n\n## Test request"
"generated_text": "\n\nDeep learning is a subset of machine learning"
}
]

View File

@ -11,32 +11,32 @@
},
{
"id": 338,
"logprob": -0.7133789,
"logprob": -0.6201172,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"logprob": -13.6484375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"logprob": -0.003894806,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6386719,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"logprob": -6.46875,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"logprob": -6.6875,
"text": "\n"
}
],
@ -44,66 +44,66 @@
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"logprob": -0.008979797,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027313232,
"logprob": -8.34465e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.0009407997,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0623207e-05,
"logprob": -0.0003838539,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5361328,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17578125,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011539459,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.47436523,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027680397,
"logprob": -0.00024354458,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6582031,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092840195,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19470215,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
}

View File

@ -5,95 +5,95 @@
"generated_tokens": 10,
"prefill": [
{
"id": 16030,
"id": 338,
"logprob": null,
"text": "is"
},
{
"id": 16030,
"logprob": -13.328125,
"text": "gradient"
},
{
"id": 26815,
"logprob": -6.4960938,
"logprob": -0.24023438,
"text": "descent"
},
{
"id": 29973,
"logprob": -5.1484375,
"logprob": -3.1386719,
"text": "?"
},
{
"id": 13,
"logprob": -4.0351562,
"text": "\n"
},
{
"id": 13,
"logprob": -5.2265625,
"logprob": -3.0878906,
"text": "\n"
}
],
"seed": 0,
"tokens": [
{
"id": 10994,
"logprob": -1.1542969,
"special": false,
"text": "Hello"
},
{
"id": 29991,
"id": 25584,
"logprob": 0.0,
"special": false,
"text": "!"
"text": "Grad"
},
{
"id": 739,
"id": 993,
"logprob": 0.0,
"special": false,
"text": " It"
"text": "ient"
},
{
"id": 2444,
"logprob": -0.42260742,
"special": false,
"text": " seems"
},
{
"id": 366,
"id": 2726,
"logprob": 0.0,
"special": false,
"text": " you"
"text": " Des"
},
{
"id": 29915,
"id": 1760,
"logprob": 0.0,
"special": false,
"text": "'"
"text": "cent"
},
{
"id": 276,
"logprob": -0.9838867,
"id": 313,
"logprob": -0.12322998,
"special": false,
"text": "re"
"text": " ("
},
{
"id": 3211,
"id": 29954,
"logprob": 0.0,
"special": false,
"text": " address"
"text": "G"
},
{
"id": 292,
"id": 29928,
"logprob": 0.0,
"special": false,
"text": "ing"
"text": "D"
},
{
"id": 263,
"logprob": -0.15124512,
"id": 29897,
"logprob": 0.0,
"special": false,
"text": " a"
"text": ")"
},
{
"id": 338,
"logprob": -0.6040039,
"special": false,
"text": " is"
},
{
"id": 385,
"logprob": -0.1796875,
"special": false,
"text": " an"
}
],
"top_tokens": null
},
"generated_text": "What is gradient descent?\n\nHello! It seems you're addressing a"
"generated_text": "What is gradient descent?\nGradient Descent (GD) is an"
}

View File

@ -12,32 +12,32 @@
},
{
"id": 338,
"logprob": -0.7133789,
"logprob": -0.6201172,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9296875,
"logprob": -13.6484375,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.048919678,
"logprob": -0.003894806,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6386719,
"text": "?"
},
{
"id": 13,
"logprob": -2.8105469,
"logprob": -6.46875,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"logprob": -6.6875,
"text": "\n"
}
],
@ -45,68 +45,68 @@
"tokens": [
{
"id": 25584,
"logprob": -0.017028809,
"logprob": -0.008979797,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028476715,
"logprob": -8.34465e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023971558,
"logprob": -0.00097084045,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"logprob": -0.0003838539,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.23840332,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.000116467476,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.47436523,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027871132,
"logprob": -0.0002501011,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6582031,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092840195,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.18933105,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
},
{
"details": {
@ -121,32 +121,32 @@
},
{
"id": 338,
"logprob": -0.7128906,
"logprob": -0.6113281,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"logprob": -13.6640625,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.05053711,
"logprob": -0.003929138,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0058594,
"logprob": -2.625,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"logprob": -6.484375,
"text": "\n"
},
{
"id": 13,
"logprob": -0.84521484,
"logprob": -6.6875,
"text": "\n"
}
],
@ -154,68 +154,68 @@
"tokens": [
{
"id": 25584,
"logprob": -0.018859863,
"logprob": -0.009017944,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.002822876,
"logprob": -9.536743e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.00097084045,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"logprob": -0.0003838539,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.0001155138,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.47436523,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027036667,
"logprob": -0.0002501011,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.0009279251,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.18933105,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
},
{
"details": {
@ -230,32 +230,32 @@
},
{
"id": 338,
"logprob": -0.71484375,
"logprob": -0.609375,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"logprob": -13.671875,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.049346924,
"logprob": -0.0040016174,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6230469,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"logprob": -6.453125,
"text": "\n"
},
{
"id": 13,
"logprob": -0.86328125,
"logprob": -6.6875,
"text": "\n"
}
],
@ -263,68 +263,68 @@
"tokens": [
{
"id": 25584,
"logprob": -0.017196655,
"logprob": -0.008956909,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0028438568,
"logprob": -8.34465e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.0009407997,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.026558e-05,
"logprob": -0.0003721714,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17602539,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011622906,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.48608398,
"logprob": -0.010406494,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"logprob": -0.0002501011,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.00092601776,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19177246,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
},
{
"details": {
@ -339,32 +339,32 @@
},
{
"id": 338,
"logprob": -0.7192383,
"logprob": -0.609375,
"text": "is"
},
{
"id": 16030,
"logprob": -13.9375,
"logprob": -13.6640625,
"text": "gradient"
},
{
"id": 26815,
"logprob": -0.050445557,
"logprob": -0.0038967133,
"text": "descent"
},
{
"id": 29973,
"logprob": -3.0078125,
"logprob": -2.6347656,
"text": "?"
},
{
"id": 13,
"logprob": -2.8242188,
"logprob": -6.453125,
"text": "\n"
},
{
"id": 13,
"logprob": -0.8276367,
"logprob": -6.6875,
"text": "\n"
}
],
@ -372,67 +372,67 @@
"tokens": [
{
"id": 25584,
"logprob": -0.01727295,
"logprob": -0.008979797,
"special": false,
"text": "Grad"
},
{
"id": 993,
"logprob": -0.0027542114,
"logprob": -9.536743e-07,
"special": false,
"text": "ient"
},
{
"id": 26815,
"logprob": -0.023254395,
"logprob": -0.0009407997,
"special": false,
"text": " descent"
},
{
"id": 338,
"logprob": -2.0384789e-05,
"logprob": -0.00038409233,
"special": false,
"text": " is"
},
{
"id": 263,
"logprob": -0.5229492,
"id": 385,
"logprob": -0.24499512,
"special": false,
"text": " a"
},
{
"id": 937,
"logprob": -0.17126465,
"special": false,
"text": " first"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 2098,
"logprob": -0.00011301041,
"special": false,
"text": "order"
"text": " an"
},
{
"id": 13883,
"logprob": -0.48608398,
"logprob": -0.010414124,
"special": false,
"text": " optimization"
},
{
"id": 5687,
"logprob": -0.00027894974,
"logprob": -0.00024354458,
"special": false,
"text": " algorithm"
},
{
"id": 15574,
"logprob": -0.6435547,
"special": false,
"text": " commonly"
},
{
"id": 1304,
"logprob": -0.0009279251,
"special": false,
"text": " used"
},
{
"id": 297,
"logprob": -0.19470215,
"special": false,
"text": " in"
}
],
"top_tokens": null
},
"generated_text": "Gradient descent is a first-order optimization algorithm"
"generated_text": "Gradient descent is an optimization algorithm commonly used in"
}
]

View File

@ -11,57 +11,57 @@
},
{
"id": 3226,
"logprob": -8.9453125,
"logprob": -9.0234375,
"text": " ge"
},
{
"id": 21017,
"logprob": -8.8515625,
"logprob": -9.0859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.21875,
"logprob": -0.25585938,
"text": "_"
},
{
"id": 6009,
"logprob": -1.2773438,
"logprob": -2.1972656,
"text": "mean"
},
{
"id": 26,
"logprob": -0.25195312,
"logprob": -0.2998047,
"text": "("
},
{
"id": 62,
"logprob": -4.8203125,
"logprob": -5.6445312,
"text": "L"
},
{
"id": 44,
"logprob": -3.7734375,
"logprob": -3.0839844,
"text": ":"
},
{
"id": 1682,
"logprob": -0.8310547,
"logprob": -0.6748047,
"text": " List"
},
{
"id": 77,
"logprob": -0.22766113,
"logprob": -0.3864746,
"text": "["
},
{
"id": 1808,
"logprob": -0.46240234,
"logprob": -0.9355469,
"text": "float"
},
{
"id": 10794,
"logprob": -3.0234375,
"logprob": -2.5371094,
"text": "]):"
}
],
@ -69,7 +69,7 @@
"tokens": [
{
"id": 284,
"logprob": -0.04626465,
"logprob": -1.1679688,
"special": false,
"text": "\n "
},

View File

@ -11,57 +11,57 @@
},
{
"id": 3226,
"logprob": -8.9453125,
"logprob": -9.015625,
"text": " ge"
},
{
"id": 21017,
"logprob": -8.859375,
"logprob": -9.0859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.21984863,
"logprob": -0.25585938,
"text": "_"
},
{
"id": 6009,
"logprob": -1.2861328,
"logprob": -2.2304688,
"text": "mean"
},
{
"id": 26,
"logprob": -0.25219727,
"logprob": -0.29760742,
"text": "("
},
{
"id": 62,
"logprob": -4.8007812,
"logprob": -5.6796875,
"text": "L"
},
{
"id": 44,
"logprob": -3.7949219,
"logprob": -3.0742188,
"text": ":"
},
{
"id": 1682,
"logprob": -0.8046875,
"logprob": -0.67626953,
"text": " List"
},
{
"id": 77,
"logprob": -0.22424316,
"logprob": -0.38842773,
"text": "["
},
{
"id": 1808,
"logprob": -0.46191406,
"logprob": -0.9165039,
"text": "float"
},
{
"id": 10794,
"logprob": -3.0253906,
"logprob": -2.5527344,
"text": "]):"
}
],
@ -69,7 +69,7 @@
"tokens": [
{
"id": 284,
"logprob": 0.0,
"logprob": -0.048583984,
"special": false,
"text": "\n "
},

View File

@ -26,7 +26,7 @@
},
{
"id": 259,
"logprob": -0.46948242,
"logprob": -0.47070312,
"special": false,
"text": " "
},
@ -38,7 +38,7 @@
},
{
"id": 35622,
"logprob": -0.79589844,
"logprob": -0.796875,
"special": false,
"text": " cloud"
},
@ -75,5 +75,5 @@
],
"top_tokens": null
},
"generated_text": "Why is the sky blue?blue sky, clouds and clouds"
"generated_text": "Why is the sky blue?blue sky , clouds and clouds"
}

View File

@ -4,7 +4,9 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_fp8_kv_cache_handle(launcher):
with launcher(
"meta-llama/Meta-Llama-3-8B", num_shard=2, kv_cache_dtype="fp8_e5m2"
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
num_shard=2,
kv_cache_dtype="fp8_e4m3fn",
) as handle:
yield handle
@ -25,7 +27,7 @@ async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snaps
assert (
response.generated_text
== " Deep learning is a subset of machine learning that is"
== " Deep learning is a subset of machine learning that involves"
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@ -69,7 +71,7 @@ async def test_flash_llama_fp8_kv_cache_load(
assert len(responses) == 4
assert (
responses[0].generated_text
== " Deep learning is a subset of machine learning that is"
== " Deep learning is a subset of machine learning that involves"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]

View File

@ -3,7 +3,11 @@ import pytest
@pytest.fixture(scope="module")
def flash_mixtral_gptq_handle(launcher):
with launcher("TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ", num_shard=2) as handle:
with launcher(
"TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ",
revision="gptq-4bit-128g-actorder_True",
num_shard=2,
) as handle:
yield handle
@ -16,7 +20,12 @@ async def flash_mixtral_gptq(flash_mixtral_gptq_handle):
@pytest.mark.asyncio
async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
response = await flash_mixtral_gptq.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
"What is deep learning?", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert (
response.generated_text == "\n\nDeep learning is a subset of machine learning"
)
assert response == response_snapshot
@ -25,7 +34,7 @@ async def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):
@pytest.mark.asyncio
async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):
response = await flash_mixtral_gptq.generate(
"Test request",
"What is deep learning?",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
@ -41,6 +50,10 @@ async def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapsh
)
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is deep learning?\nDeep Learning is a subset of Machine Learning,"
)
assert response == response_snapshot
@ -49,10 +62,14 @@ async def test_flash_mixtral_gptq_load(
flash_mixtral_gptq, generate_load, response_snapshot
):
responses = await generate_load(
flash_mixtral_gptq, "Test request", max_new_tokens=10, n=4
flash_mixtral_gptq, "What is deep learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert (
responses[0].generated_text
== "\n\nDeep learning is a subset of machine learning"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]
), f"{[r.generated_text for r in responses]}"

View File

@ -1,5 +1,4 @@
import pytest
import base64
@pytest.fixture(scope="module")
@ -20,24 +19,11 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
return flash_pali_gemma_handle.client
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
cow = get_cow_beach()
inputs = f"![]({cow})Where is the cow standing?\n"
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):
inputs = f"![]({cow_beach})Where is the cow standing?\n"
response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)
assert response.generated_text == "beach"
@ -47,9 +33,9 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
chicken = get_chicken()
cow_beach = get_cow_beach()
async def test_flash_pali_gemma_two_images(
flash_pali_gemma, response_snapshot, chicken, cow_beach
):
response = await flash_pali_gemma.generate(
f"caption![]({chicken})![]({cow_beach})\n",
max_new_tokens=20,

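The PaliGemma test above, like the idefics, idefics2, llava-next and mllama tests further down, stops defining its own get_chicken()/get_cow_beach() helpers and instead receives `chicken` and `cow_beach` as pytest fixture arguments. The shared fixture definitions are not part of this diff; the sketch below only illustrates what they could look like, reusing the image paths and base64 data-URI construction from the removed helpers (the conftest.py location and the `_image_data_uri` helper name are assumptions):

import base64

import pytest


def _image_data_uri(path: str) -> str:
    # Encode a local test image as a base64 data URI, exactly as the removed
    # get_chicken()/get_cow_beach() helpers did.
    with open(path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"


@pytest.fixture
def chicken() -> str:
    return _image_data_uri("integration-tests/images/chicken_on_money.png")


@pytest.fixture
def cow_beach() -> str:
    return _image_data_uri("integration-tests/images/cow_beach.png")

With fixtures like these, each test simply adds `chicken` and/or `cow_beach` to its signature, as the updated test_flash_pali_gemma_two_images above does.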
View File

@ -25,7 +25,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "Gradient descent is a first-order optimization algorithm"
== "Gradient descent is an optimization algorithm commonly used in"
)
assert response == response_snapshot
@ -33,7 +33,7 @@ async def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):
@pytest.mark.asyncio
async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
response = await flash_phi35_moe.generate(
"What is gradient descent?\n\n",
"What is gradient descent?\n",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
@ -51,7 +51,7 @@ async def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):
assert response.details.generated_tokens == 10
assert (
response.generated_text
== "What is gradient descent?\n\nHello! It seems you're addressing a"
== "What is gradient descent?\nGradient Descent (GD) is an"
)
assert response == response_snapshot
@ -66,7 +66,7 @@ async def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_sna
assert responses[0].details.generated_tokens == 10
assert (
responses[0].generated_text
== "Gradient descent is a first-order optimization algorithm"
== "Gradient descent is an optimization algorithm commonly used in"
)
assert all(
[r.generated_text == responses[0].generated_text for r in responses]

View File

@ -25,7 +25,6 @@ async def llama_grammar(llama_grammar_handle):
@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
class Weather(BaseModel):
unit: str
temperature: List[int]

View File

@ -1,5 +1,4 @@
import pytest
import base64
@pytest.fixture(scope="module")
@ -16,22 +15,8 @@ async def idefics(idefics_handle):
return idefics_handle.client
# TODO fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
async def test_idefics(idefics, response_snapshot):
chicken = get_chicken()
async def test_idefics(idefics, response_snapshot, chicken):
response = await idefics.generate(
f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10,
@ -48,9 +33,7 @@ async def test_idefics(idefics, response_snapshot):
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_idefics_two_images(idefics, response_snapshot):
chicken = get_chicken()
cow_beach = get_cow_beach()
async def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach):
response = await idefics.generate(
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
max_new_tokens=20,
@ -63,8 +46,7 @@ async def test_idefics_two_images(idefics, response_snapshot):
@pytest.mark.release
@pytest.mark.asyncio
async def test_idefics_load(idefics, generate_load, response_snapshot):
chicken = get_chicken()
async def test_idefics_load(idefics, generate_load, response_snapshot, chicken):
responses = await generate_load(
idefics,
f"User:![]({chicken})Can you tell me a very short story based on the image?",

View File

@ -1,18 +1,4 @@
import pytest
import base64
# TODO fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module")
@ -31,8 +17,9 @@ async def flash_idefics2_next(flash_idefics2_next_handle):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
chicken = get_chicken()
async def test_flash_idefics2_next_simple(
flash_idefics2_next, response_snapshot, chicken
):
response = await flash_idefics2_next.generate(
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
max_new_tokens=10,
@ -46,9 +33,9 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
chicken = get_chicken()
cow_beach = get_cow_beach()
async def test_flash_idefics2_two_images(
flash_idefics2_next, response_snapshot, chicken, cow_beach
):
response = await flash_idefics2_next.generate(
f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
max_new_tokens=20,
@ -87,9 +74,8 @@ async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snap
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_load(
flash_idefics2_next, generate_load, response_snapshot
flash_idefics2_next, generate_load, response_snapshot, chicken
):
chicken = get_chicken()
responses = await generate_load(
flash_idefics2_next,
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",

View File

@ -1,12 +1,4 @@
import pytest
import base64
# TODO fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.fixture(scope="module")
@ -29,8 +21,7 @@ async def flash_llava_next(flash_llava_next_handle):
@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot):
chicken = get_chicken()
async def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken):
response = await flash_llava_next.generate(
f"User:![]({chicken})Can you tell me a very short story based on the image?",
max_new_tokens=10,
@ -70,9 +61,8 @@ async def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llava_next_load(
flash_llava_next, generate_load, response_snapshot
flash_llava_next, generate_load, response_snapshot, chicken
):
chicken = get_chicken()
responses = await generate_load(
flash_llava_next,
f"User:![]({chicken})Can you tell me a very short story based on the image?",

View File

@ -1,5 +1,4 @@
import pytest
import base64
import asyncio
@ -15,22 +14,8 @@ async def mllama(mllama_handle):
return mllama_handle.client
# TODO fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
# chicken = get_chicken()
response = await mllama.chat(
max_tokens=10,
temperature=0.0,

View File

@ -68,7 +68,7 @@ fn get_config(
fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
let compute_capability = gpu::get_cuda_capability();
let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
let mut prefix_caching: Option<String> = std::env::var("PREFIX_CACHING").ok();
let mut attention: Option<String> = std::env::var("ATTENTION").ok();
if let Some(config) = config {
if prefix_caching.is_none() {
@ -94,7 +94,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
prefix_caching = Some("0".to_string());
}
match config.model_type.as_deref() {
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
if attention.is_none() {
@ -124,6 +124,10 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
}
}
}
if attention == Some("paged".to_string()) && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching on paged attention");
prefix_caching = Some("0".to_string());
}
let attention = attention.unwrap_or("flashinfer".to_string());
let prefix_caching = prefix_caching.unwrap_or("true".to_string());
@ -303,6 +307,9 @@ impl std::fmt::Display for Dtype {
#[derive(Clone, Copy, Debug, ValueEnum)]
enum KVCacheDtype {
#[clap(name = "fp8_e4m3fn")]
Fp8e4m3fn,
#[clap(name = "fp8_e5m2")]
Fp8e5m2,
}
@ -310,6 +317,9 @@ enum KVCacheDtype {
impl std::fmt::Display for KVCacheDtype {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
KVCacheDtype::Fp8e4m3fn => {
write!(f, "fp8_e4m3fn")
}
KVCacheDtype::Fp8e5m2 => {
write!(f, "fp8_e5m2")
}
@ -420,7 +430,7 @@ struct Args {
/// Specify the dtype for the key-value cache. When this option is not provided,
/// the dtype of the model is used (typically `float16` or `bfloat16`). Currently
/// the only supported value is `fp8_e5m2` on CUDA.
/// the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA.
#[clap(long, env, value_enum)]
kv_cache_dtype: Option<KVCacheDtype>,
@ -1094,6 +1104,8 @@ fn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {
}
}
}
} else {
break;
}
}
}
@ -1497,6 +1509,10 @@ fn spawn_webserver(
router_args.push(revision.to_string())
}
if args.trust_remote_code {
router_args.push("--trust-remote-code".to_string());
}
if args.json_output {
router_args.push("--json-output".to_string());
}
@ -1678,7 +1694,7 @@ fn main() -> Result<(), LauncherError> {
};
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("USE_PREFIX_CACHING", prefix_caching);
std::env::set_var("PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let max_input_tokens = {
@ -1729,12 +1745,6 @@ fn main() -> Result<(), LauncherError> {
"`max_input_tokens must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_input_tokens
)));
}
if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
@ -1788,12 +1798,6 @@ fn main() -> Result<(), LauncherError> {
}
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",

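The launcher changes above add an `fp8_e4m3fn` variant to `--kv-cache-dtype`, rename the `USE_PREFIX_CACHING` environment variable to `PREFIX_CACHING`, and make paged attention disable prefix caching when it is not set explicitly. A minimal sketch of driving the launcher with these settings follows; the flag spellings are derived from the clap definitions shown above and the model id reuses the FP8 checkpoint from the kv-cache fixture earlier in this diff, so treat the exact invocation as illustrative rather than authoritative:

import os
import subprocess

# Illustrative only: exercise the renamed PREFIX_CACHING/ATTENTION variables
# and the new fp8_e4m3fn KV-cache dtype in one launcher invocation.
env = dict(os.environ, ATTENTION="paged", PREFIX_CACHING="0")
subprocess.run(
    [
        "text-generation-launcher",
        "--model-id", "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
        "--num-shard", "2",
        "--kv-cache-dtype", "fp8_e4m3fn",
    ],
    env=env,
    check=True,
)

Setting ATTENTION="paged" without PREFIX_CACHING would now also disable prefix caching, via the new branch in resolve_attention.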
View File

@ -1,7 +1,12 @@
{
lib,
mkShell,
black,
cmake,
isort,
ninja,
which,
cudaPackages,
openssl,
pkg-config,
protobuf,
@ -11,14 +16,17 @@
ruff,
rust-bin,
server,
# Enable dependencies for building CUDA packages. Useful for e.g.
# developing marlin/moe-kernels in-place.
withCuda ? false,
}:
mkShell {
buildInputs =
nativeBuildInputs =
[
black
isort
openssl.dev
pkg-config
(rust-bin.stable.latest.default.override {
extensions = [
@ -31,6 +39,19 @@ mkShell {
redocly
ruff
]
++ (lib.optionals withCuda [
cmake
ninja
which
# For most Torch-based extensions, setting CUDA_HOME is enough, but
# some custom CMake builds (e.g. vLLM) also need to have nvcc in PATH.
cudaPackages.cuda_nvcc
]);
buildInputs =
[
openssl.dev
]
++ (with python3.pkgs; [
venvShellHook
docker
@ -40,10 +61,29 @@ mkShell {
pytest
pytest-asyncio
syrupy
]);
])
++ (lib.optionals withCuda (
with cudaPackages;
[
cuda_cccl
cuda_cudart
cuda_nvrtc
cuda_nvtx
cuda_profiler_api
cudnn
libcublas
libcusolver
libcusparse
]
));
inputsFrom = [ server ];
env = lib.optionalAttrs withCuda {
CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" python3.pkgs.torch.cudaCapabilities;
};
venvDir = "./.venv";
postVenvCreation = ''
@ -51,6 +91,7 @@ mkShell {
( cd server ; python -m pip install --no-dependencies -e . )
( cd clients/python ; python -m pip install --no-dependencies -e . )
'';
postShellHook = ''
unset SOURCE_DATE_EPOCH
export PATH=$PATH:~/.cargo/bin

View File

@ -34,6 +34,10 @@ message InfoResponse {
string device_type = 3;
optional uint32 window_size = 4;
uint32 speculate = 5;
bool support_chunking = 6;
bool use_prefix_caching = 7;
string attention_impl = 8;
uint32 block_size = 9;
}
/// Empty request
@ -135,10 +139,14 @@ message Request {
repeated uint32 slots = 10;
/// LORA adapter index
optional string adapter_id = 11;
/// Prefix length that can be retrieved from the KV cache.
uint32 prefix_len = 12;
/// Tokens that can be retrieved from the KV cache.
/// This value is set for the first prefill and never reset
uint32 cache_len = 12;
/// Context truncation
bool add_special_tokens = 13;
/// Chunk of tokens that must be computed for the first prefill
/// This value is set for the first prefill and never reset
optional uint32 chunk_len = 14;
}
message Batch {
@ -163,6 +171,8 @@ message CachedBatch {
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Number of tokens in the next forward
uint32 current_tokens = 5;
}
enum FinishReason {
@ -220,6 +230,8 @@ message FilterBatchResponse {
message PrefillRequest {
/// Batch
Batch batch = 1;
/// Optional cached batch
CachedBatch cached_batch = 2;
}
message PrefillResponse {
@ -233,6 +245,8 @@ message PrefillResponse {
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message DecodeRequest {

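The new protobuf fields above describe prefill chunking: `cache_len` is the number of prompt tokens already retrievable from the KV cache, and `chunk_len` optionally bounds how many of the remaining tokens are computed in the first prefill. The arithmetic is simple; the sketch below is inferred only from the field comments, not from code in this commit:

def first_prefill_tokens(input_ids, cache_len, chunk_len=None):
    # Tokens the server actually has to compute on the first prefill:
    # cache_len tokens come from the KV cache, and chunk_len (when set)
    # caps how many of the remaining tokens go into this forward pass.
    remaining = input_ids[cache_len:]
    if chunk_len is not None:
        remaining = remaining[:chunk_len]
    return remaining


# Example: a 10-token prompt with 4 cached tokens, chunked 3 tokens at a time.
assert first_prefill_tokens(list(range(10)), cache_len=4, chunk_len=3) == [4, 5, 6]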
View File

@ -150,6 +150,7 @@ pub enum Config {
Idefics2(Idefics2),
Ssm,
GptBigcode,
Granite,
Santacoder,
Bloom,
Mpt,

View File

@ -8,6 +8,7 @@ pub mod validation;
mod kserve;
pub mod logging;
mod sagemaker;
pub mod usage_stats;
mod vertex;
@ -18,45 +19,6 @@ use tracing::warn;
use utoipa::ToSchema;
use validation::Validation;
#[derive(PartialEq)]
pub enum Attention {
Paged,
FlashDecoding,
FlashInfer,
}
impl Attention {
pub fn block_size(&self) -> u32 {
match self {
Attention::FlashDecoding => 256,
Attention::FlashInfer => 1,
Attention::Paged => 16,
}
}
}
#[derive(Debug)]
pub struct ParseError;
impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Cannot parse attention value")
}
}
impl std::error::Error for ParseError {}
impl std::str::FromStr for Attention {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"paged" => Ok(Attention::Paged),
"flashdecoding" => Ok(Attention::FlashDecoding),
"flashinfer" => Ok(Attention::FlashInfer),
_ => Err(ParseError),
}
}
}
/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {

View File

@ -1,748 +0,0 @@
use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use text_generation_router::config::Config;
use text_generation_router::usage_stats;
use text_generation_router::{
server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig,
};
use thiserror::Error;
use tokenizers::{processors::template::TemplateProcessing, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
#[clap(default_value = "5", long, env)]
max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)]
max_input_tokens: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "4096", long, env)]
max_batch_prefill_tokens: u32,
#[clap(long, env)]
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
#[clap(long, env)]
tokenizer_config_path: Option<String>,
#[clap(long, env)]
revision: Option<String>,
#[clap(default_value = "2", long, env)]
validation_workers: usize,
#[clap(long, env)]
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(default_value = "text-generation-inference.router", long, env)]
otlp_service_name: String,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
api_key: Option<String>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
#[clap(long, env, default_value_t)]
disable_usage_stats: bool,
#[clap(long, env, default_value_t)]
disable_crash_reports: bool,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
let args = Args::parse();
// Pattern match configuration
let Args {
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
hostname,
port,
master_shard_uds_path,
tokenizer_name,
tokenizer_config_path,
revision,
validation_workers,
json_output,
otlp_endpoint,
otlp_service_name,
cors_allow_origin,
api_key,
ngrok,
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
command,
} = args;
let print_schema_command = match command {
Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
init_logging(otlp_endpoint, otlp_service_name, json_output);
false
}
};
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation(
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
AllowOrigin::list(
cors_allow_origin
.iter()
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
)
});
// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();
// Tokenizer instance
// This will only be used to validate payloads
let local_path = Path::new(&tokenizer_name);
// Shared API builder initialization
let api_builder = || {
let mut builder = ApiBuilder::new()
.with_progress(false)
.with_token(authorization_token);
if let Ok(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE") {
builder = builder.with_cache_dir(cache_dir.into());
}
builder
};
// Decide if we need to use the API based on the revision and local path
let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();
// Initialize API if needed
#[derive(Clone)]
enum Type {
Api(Api),
Cache(Cache),
None,
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {
tracing::info!("Using the Hugging Face API");
match api_builder().build() {
Ok(api) => Type::Api(api),
Err(_) => {
tracing::warn!("Unable to build the Hugging Face API");
Type::None
}
}
}
} else {
Type::None
};
// Load tokenizer and model info
let (
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
) = match api {
Type::None => (
Some(local_path.join("tokenizer.json")),
Some(local_path.join("config.json")),
Some(local_path.join("tokenizer_config.json")),
Some(local_path.join("preprocessor_config.json")),
Some(local_path.join("processor_config.json")),
None,
),
Type::Api(api) => {
let api_repo = api.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
let tokenizer_filename = match api_repo.get("tokenizer.json").await {
Ok(tokenizer_filename) => Some(tokenizer_filename),
Err(_) => get_base_tokenizer(&api, &api_repo).await,
};
let config_filename = api_repo.get("config.json").await.ok();
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok();
let preprocessor_config_filename = api_repo.get("preprocessor_config.json").await.ok();
let processor_config_filename = api_repo.get("processor_config.json").await.ok();
let model_info = if let Some(model_info) = get_model_info(&api_repo).await {
Some(model_info)
} else {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
None
};
(
tokenizer_filename,
config_filename,
tokenizer_config_filename,
preprocessor_config_filename,
processor_config_filename,
model_info,
)
}
Type::Cache(cache) => {
let repo = cache.repo(Repo::with_revision(
tokenizer_name.to_string(),
RepoType::Model,
revision.clone().unwrap_or_else(|| "main".to_string()),
));
(
repo.get("tokenizer.json"),
repo.get("config.json"),
repo.get("tokenizer_config.json"),
repo.get("preprocessor_config.json"),
repo.get("processor_config.json"),
None,
)
}
};
let config: Option<Config> = config_filename.and_then(|filename| {
std::fs::read_to_string(filename)
.ok()
.as_ref()
.and_then(|c| {
let config: Result<Config, _> = serde_json::from_str(c);
if let Err(err) = &config {
tracing::warn!("Could not parse config {err:?}");
}
config.ok()
})
});
let model_info = model_info.unwrap_or_else(|| HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
});
// Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
{
HubTokenizerConfig::from_file(filename)
} else {
tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
};
let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
tracing::warn!("Could not find tokenizer config locally and no API specified");
HubTokenizerConfig::default()
});
let tokenizer_class = tokenizer_config.tokenizer_class.clone();
let tokenizer: Option<Tokenizer> = tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
}
}
}
tokenizer
});
let preprocessor_config =
preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);
let processor_config = processor_config_filename
.and_then(HubProcessorConfig::from_file)
.unwrap_or_default();
tracing::info!("Using config {config:?}");
if tokenizer.is_none() {
tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
tracing::warn!("Rust input length validation and truncation is disabled");
}
// if pipeline-tag == text-generation we default to return_full_text = true
let compat_return_full_text = match &model_info.pipeline_tag {
None => {
tracing::warn!("no pipeline tag found for model {tokenizer_name}");
true
}
Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
};
// Determine the server port based on the feature and environment variable.
let port = if cfg!(feature = "google") {
std::env::var("AIP_HTTP_PORT")
.map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
.unwrap_or(port)
} else {
port
};
let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => {
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
}
};
// Only send usage stats when TGI is run in container and the function returns Some
let is_container = matches!(usage_stats::is_container(), Ok(true));
let user_agent = if !disable_usage_stats && is_container {
let reduced_args = usage_stats::Args::new(
config.clone(),
tokenizer_class,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
revision,
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
disable_usage_stats,
disable_crash_reports,
);
Some(usage_stats::UserAgent::new(reduced_args))
} else {
None
};
if let Some(ref ua) = user_agent {
let start_event =
usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);
tokio::spawn(async move {
start_event.send().await;
});
};
// Run server
let result = server::run(
master_shard_uds_path,
model_info,
compat_return_full_text,
max_concurrent_requests,
max_best_of,
max_stop_sequences,
max_top_n_tokens,
max_input_tokens,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
tokenizer,
config,
validation_workers,
addr,
cors_allow_origin,
api_key,
ngrok,
ngrok_authtoken,
ngrok_edge,
tokenizer_config,
preprocessor_config,
processor_config,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
print_schema_command,
)
.await;
match result {
Ok(_) => {
if let Some(ref ua) = user_agent {
let stop_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Stop,
None,
);
stop_event.send().await;
};
Ok(())
}
Err(e) => {
if let Some(ref ua) = user_agent {
if !disable_crash_reports {
let error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some(e.to_string()),
);
error_event.send().await;
} else {
let unknow_error_event = usage_stats::UsageStatsEvent::new(
ua.clone(),
usage_stats::EventType::Error,
Some("unknow_error".to_string()),
);
unknow_error_event.send().await;
}
};
Err(RouterError::WebServer(e))
}
}
}
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
/// - otlp_service_name service name to appear in APM
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
/// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {
let mut layers = Vec::new();
// STDOUT/STDERR layer
let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_ansi(ansi)
.with_line_number(true);
let fmt_layer = match json_output {
true => fmt_layer.json().flatten_event(true).boxed(),
false => fmt_layer.boxed(),
};
layers.push(fmt_layer);
// OpenTelemetry tracing layer
if let Some(otlp_endpoint) = otlp_endpoint {
global::set_text_map_propagator(TraceContextPropagator::new());
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(
opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint(otlp_endpoint),
)
.with_trace_config(
trace::config()
.with_resource(Resource::new(vec![KeyValue::new(
"service.name",
otlp_service_name,
)]))
.with_sampler(Sampler::AlwaysOn),
)
.install_batch(opentelemetry::runtime::Tokio);
if let Ok(tracer) = tracer {
layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
init_tracing_opentelemetry::init_propagator().unwrap();
};
}
// Filter events with LOG_LEVEL
let varname = "LOG_LEVEL";
let env_filter = if let Ok(log_level) = std::env::var(varname) {
// Override to avoid simple logs being spammed with tokio-level information
let log_level = match &log_level[..] {
"warn" => "text_generation_launcher=warn,text_generation_router=warn",
"info" => "text_generation_launcher=info,text_generation_router=info",
"debug" => "text_generation_launcher=debug,text_generation_router=debug",
log_level => log_level,
};
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy(log_level)
} else {
EnvFilter::new("info")
};
tracing_subscriber::registry()
.with(env_filter)
.with(layers)
.init();
}
/// Get model info from the Hugging Face Hub
pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
let response = api.info_request().send().await.ok()?;
if response.status().is_success() {
let hub_model_info: HubModelInfo =
serde_json::from_str(&response.text().await.ok()?).ok()?;
if let Some(sha) = &hub_model_info.sha {
tracing::info!(
"Serving revision {sha} of model {}",
hub_model_info.model_id
);
}
Some(hub_model_info)
} else {
None
}
}
/// Get the base model's tokenizer from the Hugging Face Hub
pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<PathBuf> {
let config_filename = api_repo.get("config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as a `serde_json::Value`.
let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
let api_base_repo = api.repo(Repo::with_revision(
base_model_id.to_string(),
RepoType::Model,
"main".to_string(),
));
api_base_repo.get("tokenizer.json").await.ok()
} else {
None
}
}
/// Get tokenizer_config from the Hugging Face Hub
pub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {
let tokenizer_config_filename = api_repo.get("tokenizer_config.json").await.ok()?;
// Open the file in read-only mode with buffer.
let file = File::open(tokenizer_config_filename).ok()?;
let reader = BufReader::new(file);
// Read the JSON contents of the file as an instance of `HubTokenizerConfig`.
let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)
.map_err(|e| {
tracing::warn!("Unable to parse tokenizer config: {}", e);
e
})
.ok()?;
Some(tokenizer_config)
}
/// Create a post_processor for the LlamaTokenizer
pub fn create_post_processor(
tokenizer: &Tokenizer,
tokenizer_config: &HubTokenizerConfig,
) -> Result<TemplateProcessing, tokenizers::processors::template::TemplateProcessingBuilderError> {
let add_bos_token = tokenizer_config.add_bos_token.unwrap_or(true);
let add_eos_token = tokenizer_config.add_eos_token.unwrap_or(false);
let bos_token = tokenizer_config.bos_token.as_ref();
let eos_token = tokenizer_config.eos_token.as_ref();
if add_bos_token && bos_token.is_none() {
panic!("add_bos_token = true but bos_token is None");
}
if add_eos_token && eos_token.is_none() {
panic!("add_eos_token = true but eos_token is None");
}
let mut single = Vec::new();
let mut pair = Vec::new();
let mut special_tokens = Vec::new();
if add_bos_token {
if let Some(bos) = bos_token {
let bos_token_id = tokenizer
.token_to_id(bos.as_str())
.expect("Should have found the bos token id");
special_tokens.push((bos.as_str(), bos_token_id));
single.push(format!("{}:0", bos.as_str()));
pair.push(format!("{}:0", bos.as_str()));
}
}
single.push("$A:0".to_string());
pair.push("$A:0".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
let eos_token_id = tokenizer
.token_to_id(eos.as_str())
.expect("Should have found the eos token id");
special_tokens.push((eos.as_str(), eos_token_id));
single.push(format!("{}:0", eos.as_str()));
pair.push(format!("{}:0", eos.as_str()));
}
}
if add_bos_token {
if let Some(bos) = bos_token {
pair.push(format!("{}:1", bos.as_str()));
}
}
pair.push("$B:1".to_string());
if add_eos_token {
if let Some(eos) = eos_token {
pair.push(format!("{}:1", eos.as_str()));
}
}
let post_processor = TemplateProcessing::builder()
.try_single(single)?
.try_pair(pair)?
.special_tokens(special_tokens)
.build()?;
Ok(post_processor)
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Argument validation error: {0}")]
ArgumentValidation(String),
#[error("WebServer error: {0}")]
WebServer(#[from] server::WebServerError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
}
#[cfg(test)]
mod tests {
use super::*;
use text_generation_router::TokenizerConfigToken;
#[test]
fn test_create_post_processor() {
let tokenizer_config = HubTokenizerConfig {
add_bos_token: None,
add_eos_token: None,
bos_token: Some(TokenizerConfigToken::String("<s>".to_string())),
eos_token: Some(TokenizerConfigToken::String("</s>".to_string())),
chat_template: None,
tokenizer_class: None,
completion_template: None,
};
let tokenizer =
Tokenizer::from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", None).unwrap();
let post_processor = create_post_processor(&tokenizer, &tokenizer_config).unwrap();
let expected = TemplateProcessing::builder()
.try_single("<s>:0 $A:0")
.unwrap()
.try_pair("<s>:0 $A:0 <s>:1 $B:1")
.unwrap()
.special_tokens(vec![("<s>".to_string(), 1)])
.build()
.unwrap();
assert_eq!(post_processor, expected);
}
}

82
router/src/sagemaker.rs Normal file
View File

@ -0,0 +1,82 @@
use crate::infer::Infer;
use crate::server::{chat_completions, compat_generate, completions, ComputeType};
use crate::{
ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest,
CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse,
};
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::response::Response;
use axum::Json;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use utoipa::ToSchema;
#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerRequest {
Generate(CompatGenerateRequest),
Chat(ChatRequest),
Completion(CompletionRequest),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerResponse {
Generate(GenerateResponse),
Chat(ChatCompletion),
Completion(CompletionFinal),
}
// Used for OpenAPI specs
#[allow(dead_code)]
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum SagemakerStreamResponse {
Generate(StreamResponse),
Chat(ChatCompletionChunk),
Completion(Chunk),
}
/// Generate tokens from a SageMaker request
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/invocations",
request_body = SagemakerRequest,
responses(
(status = 200, description = "Generated Chat Completion",
content(
("application/json" = SagemakerResponse),
("text/event-stream" = SagemakerStreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation", "error_type": "generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error", "error_type": "validation"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation", "error_type": "incomplete_generation"})),
)
)]
#[instrument(skip_all)]
pub(crate) async fn sagemaker_compatibility(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
info: Extension<Info>,
Json(req): Json<SagemakerRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
match req {
SagemakerRequest::Generate(req) => {
compat_generate(default_return_full_text, infer, compute_type, Json(req)).await
}
SagemakerRequest::Chat(req) => chat_completions(infer, compute_type, info, Json(req)).await,
SagemakerRequest::Completion(req) => {
completions(infer, compute_type, info, Json(req)).await
}
}
}

View File

@ -8,6 +8,10 @@ use crate::kserve::{
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::logging::trace_context_middleware;
use crate::sagemaker::{
sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
__path_sagemaker_compatibility,
};
use crate::validation::ValidationError;
use crate::vertex::vertex_compatibility;
use crate::ChatTokenizeResponse;
@ -85,7 +89,7 @@ example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer, req))]
async fn compat_generate(
pub(crate) async fn compat_generate(
Extension(default_return_full_text): Extension<bool>,
infer: Extension<Infer>,
compute_type: Extension<ComputeType>,
@ -694,7 +698,7 @@ time_per_token,
seed,
)
)]
async fn completions(
pub(crate) async fn completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
@ -1223,7 +1227,7 @@ time_per_token,
seed,
)
)]
async fn chat_completions(
pub(crate) async fn chat_completions(
Extension(infer): Extension<Infer>,
Extension(compute_type): Extension<ComputeType>,
Extension(info): Extension<Info>,
@ -1539,11 +1543,13 @@ completions,
tokenize,
metrics,
openai_get_model_info,
sagemaker_compatibility,
),
components(
schemas(
Info,
CompatGenerateRequest,
SagemakerRequest,
GenerateRequest,
GrammarType,
ChatRequest,
@ -1566,6 +1572,8 @@ ChatCompletionTopLogprob,
ChatCompletion,
CompletionRequest,
CompletionComplete,
SagemakerResponse,
SagemakerStreamResponse,
Chunk,
Completion,
CompletionFinal,
@ -1627,13 +1635,13 @@ pub async fn run(
tokenizer_name: String,
tokenizer_config_path: Option<String>,
revision: Option<String>,
trust_remote_code: bool,
hostname: String,
port: u16,
cors_allow_origin: Option<Vec<String>>,
ngrok: bool,
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: usage_stats::UsageStatsLevel,
@ -1787,10 +1795,13 @@ pub async fn run(
let auto = transformers.getattr("AutoTokenizer")?;
let from_pretrained = auto.getattr("from_pretrained")?;
let args = (tokenizer_name.to_string(),);
let kwargs = [(
let kwargs = [
(
"revision",
revision.clone().unwrap_or_else(|| "main".to_string()),
)]
(revision.clone().unwrap_or_else(|| "main".to_string())).into_py(py),
),
("trust_remote_code", trust_remote_code.into_py(py)),
]
.into_py_dict_bound(py);
let tokenizer = from_pretrained.call(args, Some(&kwargs))?;
let save = tokenizer.getattr("save_pretrained")?;
@ -1862,7 +1873,6 @@ pub async fn run(
// max_batch_size,
revision.clone(),
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats_level,
@ -1904,7 +1914,6 @@ pub async fn run(
ngrok,
_ngrok_authtoken,
_ngrok_edge,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
model_info,
@ -1964,7 +1973,6 @@ async fn start(
ngrok: bool,
_ngrok_authtoken: Option<String>,
_ngrok_edge: Option<String>,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
model_info: HubModelInfo,
@ -2279,6 +2287,7 @@ async fn start(
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/completions", post(completions))
.route("/vertex", post(vertex_compatibility))
.route("/invocations", post(sagemaker_compatibility))
.route("/tokenize", post(tokenize));
if let Some(api_key) = api_key {
@ -2314,13 +2323,6 @@ async fn start(
.route("/metrics", get(metrics))
.route("/v1/models", get(openai_get_model_info));
// Conditional AWS Sagemaker route
let aws_sagemaker_route = if messages_api_enabled {
Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
} else {
Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
};
let compute_type =
ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));
@ -2328,8 +2330,7 @@ async fn start(
let mut app = Router::new()
.merge(swagger_ui)
.merge(base_routes)
.merge(info_routes)
.merge(aws_sagemaker_route);
.merge(info_routes);
#[cfg(feature = "google")]
{

View File

@ -93,7 +93,6 @@ pub struct Args {
// max_batch_size: Option<usize>,
revision: Option<String>,
validation_workers: usize,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel,
@ -117,7 +116,6 @@ impl Args {
// max_batch_size: Option<usize>,
revision: Option<String>,
validation_workers: usize,
messages_api_enabled: bool,
disable_grammar_support: bool,
max_client_batch_size: usize,
usage_stats_level: UsageStatsLevel,
@ -138,7 +136,6 @@ impl Args {
// max_batch_size,
revision,
validation_workers,
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
usage_stats_level,

View File

@ -1,9 +1,6 @@
use crate::infer::Infer;
use crate::server::{generate_internal, ComputeType};
use crate::{
ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest, GrammarType, Message,
StreamOptions, Tool, ToolChoice,
};
use crate::{ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
@ -22,162 +19,12 @@ pub(crate) struct GenerateVertexInstance {
pub parameters: Option<GenerateParameters>,
}
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexChat {
messages: Vec<Message>,
// Messages is ignored there.
#[serde(default)]
parameters: VertexParameters,
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]
#[cfg_attr(test, derive(Debug, PartialEq))]
pub(crate) struct VertexParameters {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: Option<String>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
/// decreasing the model's likelihood to repeat the same line verbatim.
#[serde(default)]
#[schema(example = "1.0")]
pub frequency_penalty: Option<f32>,
/// UNUSED
/// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens
/// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,
/// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,
/// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should
/// result in a ban or exclusive selection of the relevant token.
#[serde(default)]
pub logit_bias: Option<Vec<f32>>,
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each
/// output token returned in the content of message.
#[serde(default)]
#[schema(example = "false")]
pub logprobs: Option<bool>,
/// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
/// an associated log probability. logprobs must be set to true if this parameter is used.
#[serde(default)]
#[schema(example = "5")]
pub top_logprobs: Option<u32>,
/// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)]
#[schema(example = "32")]
pub max_tokens: Option<u32>,
/// UNUSED
/// How many chat completion choices to generate for each input message. Note that you will be charged based on the
/// number of generated tokens across all of the choices. Keep n as 1 to minimize costs.
#[serde(default)]
#[schema(nullable = true, example = "2")]
pub n: Option<u32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
/// increasing the model's likelihood to talk about new topics
#[serde(default)]
#[schema(nullable = true, example = 0.1)]
pub presence_penalty: Option<f32>,
/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stop: Option<Vec<String>>,
#[serde(default = "bool::default")]
pub stream: bool,
#[schema(nullable = true, example = 42)]
pub seed: Option<u64>,
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while
/// lower values like 0.2 will make it more focused and deterministic.
///
/// We generally recommend altering this or `top_p` but not both.
#[serde(default)]
#[schema(nullable = true, example = 1.0)]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the
/// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
#[serde(default)]
#[schema(nullable = true, example = 0.95)]
pub top_p: Option<f32>,
/// A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of
/// functions the model may generate JSON inputs for.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tools: Option<Vec<Tool>>,
/// A prompt to be appended before the tools
#[serde(default)]
#[schema(
nullable = true,
example = "Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables."
)]
pub tool_prompt: Option<String>,
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tool_choice: ToolChoice,
/// Response format constraints for the generation.
///
/// NOTE: A request can use `response_format` OR `tools` but not both.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub response_format: Option<GrammarType>,
/// A guideline to be used in the chat_template
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub guideline: Option<String>,
/// Options for streaming response. Only set this when you set stream: true.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stream_options: Option<StreamOptions>,
}
impl From<VertexChat> for ChatRequest {
fn from(val: VertexChat) -> Self {
Self {
messages: val.messages,
frequency_penalty: val.parameters.frequency_penalty,
guideline: val.parameters.guideline,
logit_bias: val.parameters.logit_bias,
logprobs: val.parameters.logprobs,
max_tokens: val.parameters.max_tokens,
model: val.parameters.model,
n: val.parameters.n,
presence_penalty: val.parameters.presence_penalty,
response_format: val.parameters.response_format,
seed: val.parameters.seed,
stop: val.parameters.stop,
stream_options: val.parameters.stream_options,
stream: val.parameters.stream,
temperature: val.parameters.temperature,
tool_choice: val.parameters.tool_choice,
tool_prompt: val.parameters.tool_prompt,
tools: val.parameters.tools,
top_logprobs: val.parameters.top_logprobs,
top_p: val.parameters.top_p,
}
}
}
#[derive(Clone, Deserialize, ToSchema)]
#[cfg_attr(test, derive(Debug, PartialEq))]
#[serde(untagged)]
pub(crate) enum VertexInstance {
Generate(GenerateVertexInstance),
Chat(VertexChat),
Chat(ChatRequest),
}
#[derive(Deserialize, ToSchema)]
@ -263,9 +110,8 @@ pub(crate) async fn vertex_compatibility(
},
},
VertexInstance::Chat(instance) => {
let chat_request: ChatRequest = instance.into();
let (generate_request, _using_tools): (GenerateRequest, bool) =
chat_request.try_into_generate(&infer)?;
instance.try_into_generate(&infer)?;
generate_request
}
};
@ -311,35 +157,15 @@ mod tests {
#[test]
fn vertex_deserialization() {
let string = serde_json::json!({
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
"parameters": {
"max_tokens": 128,
"top_p": 0.95,
"temperature": 0.7
}
});
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
let string = serde_json::json!({
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
});
let _request: VertexChat = serde_json::from_value(string).expect("Can deserialize");
let string = serde_json::json!({
"instances": [
{
"messages": [{"role": "user", "content": "What's Deep Learning?"}],
"parameters": {
"max_tokens": 128,
"top_p": 0.95,
"temperature": 0.7
}
}
]
});
@ -347,18 +173,16 @@ mod tests {
assert_eq!(
request,
VertexRequest {
instances: vec![VertexInstance::Chat(VertexChat {
instances: vec![VertexInstance::Chat(ChatRequest {
messages: vec![Message {
role: "user".to_string(),
content: MessageContent::SingleText("What's Deep Learning?".to_string()),
name: None,
},],
parameters: VertexParameters {
max_tokens: Some(128),
top_p: Some(0.95),
temperature: Some(0.7),
..Default::default()
}
})]
}
);

View File

@ -31,7 +31,7 @@ install: install-cuda
echo "Installed server"
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
pip install -e ".[bnb]"
pip install -e ".[bnb,marlin,moe]"
pip install nvidia-nccl-cu12==2.22.3
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm

1379
server/poetry.lock generated

File diff suppressed because it is too large

View File

@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
numpy = "^1.26"
marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -1,5 +1,5 @@
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
@ -10,7 +10,7 @@ googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version <
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
@ -38,14 +38,14 @@ pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
rich==13.9.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -2,7 +2,7 @@ import pytest
import os
from text_generation_server.pb import generate_pb2
os.environ["USE_PREFIX_CACHING"] = "1"
os.environ["PREFIX_CACHING"] = "1"
os.environ["ATTENTION"] = "flashinfer"

View File

@ -31,6 +31,7 @@ class Dtype(str, Enum):
class KVCacheDtype(str, Enum):
fp8_e4m3fn = "fp8_e4m3fn"
fp8_e5m2 = "fp8_e5m2"

View File

@ -9,6 +9,9 @@ from typing import Callable, Any
class ExceptionInterceptor(AsyncServerInterceptor):
def __init__(self, shutdown_callback):
self.shutdown_callback = shutdown_callback
async def intercept(
self,
method: Callable,
@ -25,7 +28,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
# Runtime Error cannot be recovered from
if isinstance(err, RuntimeError):
exit(1)
self.shutdown_callback()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -8,39 +8,32 @@ if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
elif SYSTEM == "rocm":
from .rocm import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
elif SYSTEM == "ipex":
from .ipex import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
reshape_and_cache,
)
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache
from .kv_cache import KVCache, get_kv_scales
__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"reshape_and_cache",
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"KVCache",
"Seqlen",

View File

@ -1,16 +1,12 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional
if ATTENTION in {"flashinfer", "flashdecoding"}:
@dataclass
class Seqlen:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cache_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]
max_q: int
@ -19,13 +15,13 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
def __init__(
self,
input_lengths,
prefix_lengths,
cache_lengths,
cu_seqlen_q=None,
max_q=None,
max_k=None,
):
self.input_lengths = input_lengths
self.prefix_lengths = prefix_lengths
self.cache_lengths = cache_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
if cu_seqlen_q is None:
@ -43,7 +39,7 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
total = self.input_lengths + self.prefix_lengths
total = self.input_lengths + self.cache_lengths
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
self.cu_seqlen_q = cu_seqlen_q
@ -54,19 +50,3 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
def clamp(self, max):
# Flash decoding doesn't need to clamp
return self
else:
@dataclass
class Seqlen:
input_lengths: torch.Tensor
prefix_lengths: torch.Tensor
cu_seqlen_q: torch.Tensor
max_q: int
max_k: int
def clamp(self, max):
if SYSTEM == "rocm":
return self
self.input_lengths = torch.clamp(self.input_lengths, max=max)
return self
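As a worked example of the cumulative-length bookkeeping performed by the constructor above, here is a small self-contained sketch (hypothetical lengths, requires only PyTorch; it mirrors the logic rather than calling the class itself):
import torch
# Two sequences: 3 and 5 new tokens, with 2 and 0 tokens already in the KV cache.
input_lengths = torch.tensor([3, 5])
cache_lengths = torch.tensor([2, 0])
# One query per sequence during decode, so a simple arange suffices in this sketch.
cu_seqlen_q = torch.arange(len(input_lengths) + 1, dtype=torch.int32)
# Keys span both cached and new tokens, so the cumulative sum uses their total.
cu_seqlen_k = torch.zeros(len(input_lengths) + 1, dtype=torch.int64)
torch.cumsum(input_lengths + cache_lengths, -1, out=cu_seqlen_k[1:])
print(cu_seqlen_q.tolist())  # [0, 1, 2]
print(cu_seqlen_k.tolist())  # [0, 5, 10]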

View File

@ -1,4 +1,5 @@
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import (
ATTENTION,
@ -7,44 +8,22 @@ from text_generation_server.models.globals import (
from text_generation_server.layers.attention import Seqlen
from typing import Optional
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
try:
from vllm._C import cache_ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION in {"flashdecoding", "flashinfer"}:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -70,6 +49,8 @@ def paged_attention(
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
can_scale = kv_cache.can_scale(kv_scales)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
@ -79,10 +60,13 @@ def paged_attention(
from text_generation_server.layers.attention.flashinfer import decode_state
return decode_state.get().forward(
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
query.contiguous(),
paged_kv_cache=(key_cache, value_cache),
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
)
elif ATTENTION == "flashdecoding":
max_q = 1
@ -98,8 +82,8 @@ def paged_attention(
softcap = 0.0
out = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
@ -123,7 +107,7 @@ def paged_attention(
else:
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
from vllm._C import ops
out = torch.empty_like(query)
@ -135,8 +119,8 @@ def paged_attention(
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@ -168,8 +152,8 @@ def paged_attention(
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@ -216,60 +200,69 @@ except ImportError:
) from e
if ATTENTION == "flashdecoding" and not V2:
raise ValueError("Flash decoding requires Flash Attention V2")
SUPPORTS_WINDOWING = V2
if ATTENTION == "flashinfer":
def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
can_scale = kv_cache.can_scale(kv_scales)
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)
if softcap is None:
softcap = 0.0
return prefill_with_paged_kv_state.get().forward(
q.contiguous(),
# TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
query.contiguous(),
causal=causal,
paged_kv_cache=(key_cache, value_cache),
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
)
elif V2:
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
# If we are using flashdecoding or paged, we always use flash-attn for
# the prefill. We have to branch on whether we use flash-attn v1 or v2.
elif V2:
out = torch.empty_like(query)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if softcap is None:
softcap = 0.0
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
query,
# flashdecoding: pass the KV caches, paged: pass the KV.
kv_cache.key if ATTENTION == "flashdecoding" else key,
kv_cache.value if ATTENTION == "flashdecoding" else value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
block_tables if ATTENTION == "flashdecoding" else None,
None,
seqlen.max_q,
seqlen.max_k,
@ -284,57 +277,45 @@ elif V2:
None,
)[0]
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
else:
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError("softcap is only available with flash attn v2")
raise NotImplementedError("softcap is not available in flash attn v1")
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
if key.shape[1] != query.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
if key.shape[1] == 1:
key = key.expand(-1, query.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
original_shape = key.shape
key = (
key.unsqueeze(2)
.expand(-1, -1, query.shape[1] // key.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
if value.shape[1] != query.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
if value.shape[1] == 1:
value = value.expand(-1, query.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
original_shape = value.shape
value = (
value.unsqueeze(2)
.expand(-1, -1, query.shape[1] // value.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
out = torch.empty_like(query)
flash_attn_cuda.fwd(
q,
k,
v,
query,
key,
value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@ -351,15 +332,8 @@ else:
return out
# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]

View File

@ -699,7 +699,6 @@ def check_args(
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx,

View File

@ -204,6 +204,7 @@ def use_decode_state(
num_kv_heads: int,
head_size: int,
page_size: int,
kv_cache_dtype: torch.dtype,
dtype: torch.dtype,
window_left: int,
):
@ -240,7 +241,7 @@ def use_decode_state(
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
data_type=dtype,
data_type=kv_cache_dtype,
q_data_type=dtype,
window_left=window_left,
)

View File

@ -1,31 +1,37 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional
SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False
def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
out = torch.empty_like(q)
if softcap is not None:
raise NotImplementedError("softcap is not available in IPEX")
out = torch.empty_like(query)
# No need to check window_size_left here (windowing is not supported); it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention(
q.contiguous() if q.device.type == "xpu" else q,
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
value.contiguous() if value.device.type == "xpu" else value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@ -42,39 +48,32 @@ def attention(
return out
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is not available in IPEX")
out = torch.empty_like(query)
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
seqlen.input_lengths,
input_lengths,
BLOCK_SIZE,
max_s,
None,
@ -83,9 +82,7 @@ def paged_attention(
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]

View File

@ -1,9 +1,38 @@
from typing import Tuple
from dataclasses import dataclass, field
from loguru import logger
import torch
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import reshape_and_cache
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights
@dataclass
class KVScales:
"""
Key-value scales for FP8 KV cache.
This data class stores key and value scales both as a GPU tensor and
as a CPU float. This inconvenience is necessary because some functions
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
(e.g. flashinfer) take scales as a CPU scalar.
"""
key_scale: torch.Tensor
value_scale: torch.Tensor
key_scale_cpu: float = field(init=False)
value_scale_cpu: float = field(init=False)
def __post_init__(self):
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
raise ValueError("Key and value scales must be scalar tensors.")
self.key_scale_cpu = self.key_scale.item()
self.value_scale_cpu = self.value_scale.item()
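A minimal usage sketch of the dataclass above (hypothetical values; it assumes the text_generation_server package from this repo is importable, and uses CPU tensors for illustration while real scales live on the GPU next to the weights):
import torch
from text_generation_server.layers.attention.kv_cache import KVScales  # module path as used elsewhere in this diff
scales = KVScales(key_scale=torch.tensor(0.02), value_scale=torch.tensor(0.04))
print(scales.key_scale_cpu, scales.value_scale_cpu)  # ~0.02 ~0.04, plain Python floats
# Non-scalar tensors are rejected in __post_init__:
# KVScales(key_scale=torch.ones(2), value_scale=torch.ones(2))  # raises ValueError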
class KVCache:
@ -24,11 +53,11 @@ class KVCache:
):
"""Construct the key-value cache for a layer."""
if dtype == torch.float8_e5m2 and (
if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
ATTENTION != "flashinfer" or SYSTEM != "cuda"
):
raise ValueError(
"float8_e5m2 KV cache is currently only supported for flashinfer on CUDA"
"FP8 KV cache is currently only supported for flashinfer on CUDA"
)
element_size = torch.tensor([], dtype=dtype).element_size()
@ -77,6 +106,33 @@ class KVCache:
),
)
def can_scale(self, kv_scales: KVScales) -> bool:
"""Check if the cache can be scaled by the given scales."""
if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
return False
elif (
self.dtype == torch.float8_e4m3fn
and ATTENTION == "flashinfer"
and SYSTEM == "cuda"
):
log_once(
logger.info,
"Using FP8 KV cache scales",
)
return True
else:
# We have scales, but not the correct FP8 cache type, so log this once.
log_once(
logger.info,
"Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
)
return False
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache[0].dtype
@property
def key(self):
"""Get the key cache."""
@ -95,18 +151,34 @@ class KVCache:
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1]
if self.can_scale(kv_scales):
if kv_scales.key_scale_cpu != 1.0:
key = fp8_quantize(
key.float(),
scale=kv_scales.key_scale,
qdtype=self.dtype,
scalar=True,
)[0]
if kv_scales.value_scale_cpu != 1.0:
value = fp8_quantize(
value.float(),
scale=kv_scales.value_scale,
qdtype=self.dtype,
scalar=True,
)[0]
if ATTENTION in {"flashdecoding", "flashinfer"}:
# TODO: add scale
key = key.to(key_cache.dtype)
value = value.to(value_cache.dtype)
if key_cache.dtype == torch.float8_e5m2:
# Torch index_put does not support float8_e5m2 yet, so
if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
# Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
# put as raw data instead.
key_cache = key_cache.view(torch.uint8)
value_cache = value_cache.view(torch.uint8)
@ -116,4 +188,59 @@ class KVCache:
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
reshape_and_cache(key, value, key_cache, value_cache, slots)
paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if SYSTEM == "cuda":
try:
from vllm._C import cache_ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
elif SYSTEM == "rocm":
try:
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
else:
raise NotImplementedError(
f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
)
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
"""Load KV cache scales."""
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
value_scale = key_scale
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
f"{prefix}.v_scale"
):
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
elif weights.has_tensor(f"{prefix}.kv_scale"):
# Fall back to older more coarse-grained scale when available.
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
value_scale = key_scale
return KVScales(key_scale=key_scale, value_scale=value_scale)
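To make the lookup order explicit, here is a plain-Python sketch of the precedence implemented by get_kv_scales above; toy_get_kv_scales and the dict are hypothetical stand-ins for the Weights loader:
import torch
def toy_get_kv_scales(tensors: dict, prefix: str):
    # Same precedence as get_kv_scales above, with a plain dict standing in
    # for the Weights loader: per-tensor k_scale/v_scale win, the coarser
    # kv_scale is the fallback, and 1.0 means "no scaling".
    key_scale = torch.tensor(1.0)
    value_scale = key_scale
    if f"{prefix}.k_scale" in tensors and f"{prefix}.v_scale" in tensors:
        key_scale = tensors[f"{prefix}.k_scale"].float()
        value_scale = tensors[f"{prefix}.v_scale"].float()
    elif f"{prefix}.kv_scale" in tensors:
        key_scale = tensors[f"{prefix}.kv_scale"].float()
        value_scale = key_scale
    return key_scale, value_scale
print(toy_get_kv_scales({"attn.kv_scale": torch.tensor(0.08)}, "attn"))  # (tensor(0.0800), tensor(0.0800))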

View File

@ -1,8 +1,8 @@
import os
from typing import Optional
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
from loguru import logger
@ -16,8 +16,6 @@ _PARTITION_SIZE_CUSTOM = 256
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
PREFILL_IN_KV_CACHE = False
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
if use_rocm_custom_paged_attn:
@ -29,38 +27,17 @@ except ImportError as e:
)
use_rocm_custom_paged_attn = False
try:
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION == "flashdecoding":
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@ -84,10 +61,10 @@ def paged_attention(
raise RuntimeError("Paged attention doesn't support softcapping")
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
block_size = kv_cache.value.shape[3]
num_seqs, num_heads, head_size = query.shape
num_kv_heads = key_cache.shape[1]
num_kv_heads = kv_cache.key.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
use_rocm_custom_paged_attn
@ -104,7 +81,7 @@ def paged_attention(
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = seqlen.input_lengths
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
out = torch.empty_like(query)
@ -124,8 +101,8 @@ def paged_attention(
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@ -158,8 +135,8 @@ def paged_attention(
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@ -177,8 +154,8 @@ def paged_attention(
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
num_kv_heads,
softmax_scale,
block_tables,
@ -227,29 +204,36 @@ if ENGINE != "triton":
SUPPORTS_WINDOWING = False
if ENGINE == "ck":
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
):
softcap: Optional[float] = None,
):
if ENGINE == "ck":
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
out = torch.empty_like(q)
out = torch.empty_like(query)
if softcap is None:
softcap = 0.0
# No need to check window_size_left here (windowing is not supported); it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
query,
key,
value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@ -270,30 +254,19 @@ if ENGINE == "ck":
None,
)[0]
elif ENGINE == "triton":
elif ENGINE == "triton":
from .flash_attn_triton import triton_attention
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is only available with CK flash attn")
out = torch.empty_like(q)
out = torch.empty_like(query)
# No need to check window_size_left here (windowing is not supported); it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
key_cache,
value_cache,
query,
key,
value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@ -304,13 +277,12 @@ elif ENGINE == "triton":
)
return output
else:
else:
raise RuntimeError(f"Unknown attention engine {ENGINE}")
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]

View File

@ -0,0 +1,8 @@
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
from .ipex import WQLinear
elif SYSTEM == "cuda":
from .cuda import WQLinear
__all__ = ["WQLinear"]

View File

@ -0,0 +1,48 @@
from typing import Optional
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
class WQLinear(nn.Module):
def __init__(
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else self.in_features
# quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.bias = bias
self.woq_linear = (
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
self.qweight,
self.scales,
self.qzeros,
self.in_features,
self.out_features,
bias=self.bias,
group_size=self.group_size,
quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
dtype=ipex.llm.quantization.QuantDtype.INT4,
)
)
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
out = self.woq_linear(x.reshape(-1, x.shape[-1]))
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
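For readers checking the constructor's shape arithmetic, here is a minimal sketch with hypothetical layer dimensions (4096x11008 is illustrative, not taken from this diff): each int32 column of qweight packs 32 // w_bit = 8 output features.

# Illustrative only: hypothetical 4-bit AWQ layer dimensions.
w_bit = 4
in_features, out_features = 4096, 11008
packed_cols = out_features * w_bit // 32      # 8 output features per int32 column -> 1376
qweight_shape = (in_features, packed_cols)

# Matches the derivation in WQLinear.__init__ above.
assert qweight_shape[0] == in_features
assert qweight_shape[1] * 32 // w_bit == out_features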

View File

@ -1,7 +1,7 @@
import torch
from dataclasses import dataclass
from typing import Optional, Union, List
from typing import Optional, Tuple, Union, List
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
@ -26,6 +26,12 @@ def is_fbgemm_gpu_available():
return False
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
@ -51,27 +57,82 @@ def get_fp8_linear() -> torch.nn.Module:
return Fp8Linear
def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn
# The bit pattern 10000000 (-128) represents zero in e4m3fn,
# but NaN in e4m3fnuz, so here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
# For the same bit representation, the e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale
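A minimal sketch of the bit-reinterpretation argument behind normalize_e4m3fn_to_e4m3fnuz: reading the same bits as e4m3fnuz halves every value, so doubling the scale keeps weight * scale unchanged. This assumes a PyTorch build with float8 dtypes; the sample values are illustrative.

import torch

x = torch.tensor([0.5, -1.5, 2.0])
w_fn = x.to(torch.float8_e4m3fn)                     # e4m3fn weights with scale 1.0
scale_fn = torch.tensor(1.0)

# Reinterpret the same bits as e4m3fnuz (each value decodes to half)
# and double the scale, as the function above does.
w_fnuz = w_fn.view(torch.int8).view(torch.float8_e4m3fnuz)
scale_fnuz = scale_fn * 2.0

assert torch.allclose(w_fn.to(torch.float32) * scale_fn,
                      w_fnuz.to(torch.float32) * scale_fnuz)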
def fp8_quantize(
weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
weight: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[torch.Tensor] = None,
qdtype: torch.dtype = torch.float8_e4m3fn,
scalar: bool = False,
):
if FBGEMM_DYN_AVAILABLE and not scalar:
"""
This function returns the reciprocal of the scale, so that a tensor can be unscaled
by multiplying it with the returned scale. If a scale is given through the `scale`
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
be used without modification).
"""
if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
)
return qweight, scale
if marlin_kernels is not None:
shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
dtype=qdtype,
scale=scale,
scale_ub=scale_upper_bound,
)
return qweight.reshape(shape), scale
# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
if scale is None:
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
# Scale and clamp the tensor to bring it into
# the representable range of the float8 data type
# (as the default cast is unsaturated)
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
scale = scale.float().reciprocal()
else:
# Use reciprocal to avoid more expensive division.
qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
# Return both the float8 data and the inverse scale (as float),
# as both are required as inputs to torch._scaled_mm
qweight = qweight.to(qdtype)
scale = scale.float().reciprocal()
if SYSTEM == "rocm":
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
return qweight, scale
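A small usage sketch of the contract documented above: the returned scale is already the reciprocal, so dequantization is a plain multiply. This assumes the pure-PyTorch fallback path (no fbgemm or marlin kernels) and uses loose tolerances to absorb FP8 rounding.

import torch

weight = torch.randn(16, 32, dtype=torch.float32)
qweight, inv_scale = fp8_quantize(weight, scalar=True)

# Dequantize by multiplying with the returned (reciprocal) scale.
dequant = qweight.to(torch.float32) * inv_scale
assert torch.allclose(weight, dequant, rtol=0.1, atol=0.1)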
@ -92,9 +153,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
.reshape(-1)
.expand(w.shape[0])
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -125,9 +194,24 @@ class HybridFP8UnquantLoader(WeightsLoader):
)
scale = scale.reshape(-1).expand(w.shape[0])
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -154,9 +238,22 @@ class HybridFP8UnquantLoader(WeightsLoader):
]
scale = torch.cat(scale, dim=0).reshape(-1)
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -174,9 +271,16 @@ class HybridFP8UnquantLoader(WeightsLoader):
.reshape(-1)
.expand(w.shape[0])
)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
).reshape(-1)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -191,6 +295,7 @@ class Fp8Weight(Weight):
weight: torch.Tensor
dtype: torch.dtype
weight_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None
def get_linear(self, bias: torch.Tensor):
@ -200,26 +305,41 @@ class Fp8Weight(Weight):
# memory. Can be non-contiguous when we e.g. expand from scalars.
self.weight_scale = self.weight_scale.contiguous()
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
weight=self.weight,
scale=self.weight_scale,
dtype=self.dtype,
bias=bias,
input_scale=self.input_scale,
scale_upper_bound=self.activation_scale_ub,
)
class Fp8Linear(torch.nn.Module):
_device_identity_cache = {}
def __init__(
self,
qweight,
scale,
scale_upper_bound,
bias,
dtype,
qweight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[float] = None,
) -> None:
super().__init__()
if FBGEMM_MM_AVAILABLE:
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=qweight, weight_scale=scale
)
self.dtype = dtype
self.qweight = qweight
self.scale = scale
self.scale = scale.float()
self.input_scale = input_scale.float() if input_scale is not None else None
if FBGEMM_MM_AVAILABLE:
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
@ -227,6 +347,8 @@ class Fp8Linear(torch.nn.Module):
if scale_upper_bound is not None
else None
)
else:
self.scale_upper_bound = scale_upper_bound
self.bias = bias if bias is not None else None
@ -234,22 +356,46 @@ class Fp8Linear(torch.nn.Module):
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
return cls(
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
qweight=qweight,
scale=scale,
dtype=dtype,
bias=bias,
input_scale=None,
scale_upper_bound=None,
)
@classmethod
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
def from_fp8(
cls,
weight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
**kwargs,
) -> "Fp8Linear":
input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None)
if FBGEMM_DYN_AVAILABLE:
# fbgemm needs float32 scales.
scale = scale.float()
return cls(
qweight=weight,
scale=scale,
scale_upper_bound=input_scale,
input_scale=input_scale,
scale_upper_bound=scale_upper_bound,
bias=bias,
dtype=dtype,
)
@classmethod
def get_shared_device_identity(cls, device):
# Input scaling factors are no longer optional in _scaled_mm starting
# from PyTorch 2.5, so allocate a shared dummy tensor to pass as the input scale.
if device not in cls._device_identity_cache:
cls._device_identity_cache[device] = torch.ones(1, device=device)
return cls._device_identity_cache[device]
def forward(self, input: torch.Tensor) -> torch.Tensor:
if FBGEMM_MM_AVAILABLE:
qinput, scale = fp8_quantize(
@ -266,8 +412,18 @@ class Fp8Linear(torch.nn.Module):
)
return y.to(self.dtype)
qinput, scale = fp8_quantize(input, scalar=True)
output, _ = torch._scaled_mm(
qinput, scale = fp8_quantize(
input,
self.input_scale,
scale_upper_bound=self.scale_upper_bound,
scalar=True,
)
per_tensor_weights = self.scale.numel() == 1
per_tensor_activations = scale.numel() == 1
if SYSTEM != "rocm" or (per_tensor_weights and per_tensor_activations):
output = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
@ -275,6 +431,30 @@ class Fp8Linear(torch.nn.Module):
scale_b=self.scale,
bias=self.bias,
)
if isinstance(output, tuple) and len(output) == 2:
output = output[0]
else:
device_identity = None
if SYSTEM == "rocm":
device_identity = self.get_shared_device_identity(self.qweight.device)
output = torch._scaled_mm(
qinput,
self.qweight.t(),
scale_a=device_identity,
scale_b=device_identity,
out_dtype=torch.float32,
)
if isinstance(output, tuple) and len(output) == 2:
output = output[0]
output = output * scale * self.scale.t()
if self.bias is not None:
output = output + self.bias
output = output.to(dtype=self.dtype)
return output
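The ROCm fallback above runs _scaled_mm with identity scales and applies the real per-row/per-column scales afterwards. A minimal sketch of the linearity identity that makes this post-hoc rescaling valid (plain float tensors, illustrative shapes):

import torch

A = torch.randn(4, 8)        # stands in for the dequantized activations
B = torch.randn(8, 3)        # stands in for the dequantized weight (already transposed)
s_a = torch.rand(4, 1)       # per-row activation scales
s_b = torch.rand(1, 3)       # per-column weight scales

# Scaling before or after the matmul gives the same result,
# which is why output * scale * self.scale.t() recovers the scaled product.
assert torch.allclose((A * s_a) @ (B * s_b), (A @ B) * s_a * s_b, atol=1e-5)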

View File

@ -8,6 +8,11 @@ from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
if SYSTEM == "ipex":
from .ipex import QuantLinear
elif SYSTEM in {"cuda", "rocm"}:
from .triton import QuantLinear
@dataclass
class GPTQWeight(Weight):
@ -36,7 +41,7 @@ class GPTQWeight(Weight):
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
from text_generation_server.layers.awq.quantize import WQLinear
return WQLinear(
w_bit=self.bits,
@ -60,8 +65,6 @@ class GPTQWeight(Weight):
return ExllamaQuantLinear(self, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
return QuantLinear(
self.qweight,
self.qzeros,
@ -298,6 +301,7 @@ class GPTQWeightsLoader(WeightsLoader):
self._get_gptq_params(weights)
use_exllama = True
desc_act = self.desc_act
if self.bits != 4:
use_exllama = False
@ -321,7 +325,8 @@ class GPTQWeightsLoader(WeightsLoader):
if g_idx is not None:
if (
not torch.equal(
g_idx.cpu(),
# Subtract g_idx[0] so the check also holds when g_idx is sharded (TP > 1).
(g_idx - g_idx[0]).cpu(),
torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32,
@ -332,6 +337,7 @@ class GPTQWeightsLoader(WeightsLoader):
# The Exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering the input activations that are split across several GPUs
use_exllama = False
desc_act = True
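An illustration of the (g_idx - g_idx[0]) adjustment above, with made-up numbers: under TP > 1 a rank may hold a g_idx shard whose group indices start at a nonzero offset, yet the layout is still the trivial "column i belongs to group i // groupsize".

import torch

groupsize = 4
shard = torch.tensor([2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int32)   # hypothetical g_idx shard
expected = torch.tensor(
    [i // groupsize for i in range(shard.shape[0])], dtype=torch.int32
)

assert not torch.equal(shard, expected)          # the raw comparison would reject this shard
assert torch.equal(shard - shard[0], expected)   # the offset-free comparison accepts it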
from text_generation_server.layers.gptq import (
CAN_EXLLAMA,
@ -350,16 +356,16 @@ class GPTQWeightsLoader(WeightsLoader):
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and self.groupsize != -1:
if not desc_act and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
if g_idx is not None:
# qzeros and scales are sharded, so g_idx must be adjusted accordingly
g_idx = g_idx - g_idx[0]
else:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
@ -392,7 +398,7 @@ class GPTQWeightsLoader(WeightsLoader):
)
def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
@ -400,7 +406,7 @@ class GPTQWeightsLoader(WeightsLoader):
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym")
if weights.has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"

View File

@ -0,0 +1,126 @@
import math
import numpy as np
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
self.woq_linear = (
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
self.qweight,
self.scales,
self.qzeros,
self.infeatures,
self.outfeatures,
bias=self.bias,
group_size=self.groupsize,
g_idx=g_idx,
quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
dtype=ipex.llm.quantization.QuantDtype.INT4,
)
)
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
out = self.woq_linear(x.reshape(-1, x.shape[-1]))
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
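For reference, a minimal sketch (plain Python, made-up values) of the nibble packing performed by pack() above: with 4-bit quantization, eight values share one 32-bit word, lowest nibble first.

bits = 4
values = [1, 2, 3, 4, 5, 6, 7, 8]            # 32 // 4 = 8 values per 32-bit word

word = 0
for j, v in enumerate(values):
    word |= v << (bits * j)                  # same shift pattern as in pack()

unpacked = [(word >> (bits * j)) & 0xF for j in range(8)]
assert unpacked == values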

Some files were not shown because too many files have changed in this diff.